fixed Dataloader and added test cases.

parent 24a3a7bf56
commit fd518ff080
@@ -15,7 +15,7 @@ class FilePathLoader:
         self.successful_count = 0
         self.skipped_count = 0
 
-        if self.file_path_column not in dataset.column_names:
+        if self.file_path_column not in dataset.columns:
             raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
 
     def _get_image(self, item):
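The fix swaps `dataset.column_names`, which is the spelling a Hugging Face `datasets.Dataset` exposes, for `dataset.columns`, which is what a pandas `DataFrame` provides; the new tests build the dataset with `pd.DataFrame`. A minimal sketch of the membership check the new code relies on (illustrative, not repository code):

import pandas as pd

df = pd.DataFrame({"file_path": ["a.jpg"], "label": [0]})

# pandas exposes column labels via the .columns Index, not .column_names
assert "file_path" in df.columns
assert "missing_column" not in df.columns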
@@ -106,7 +106,11 @@ class JPGImageLoader:
         print(f"Skipped {self.skipped_count} images due to errors.")
 
 class AIIADataLoader:
-    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path",
+                 label_column=None, pretraining=False, **dataloader_kwargs):
+        if column not in dataset.columns:
+            raise ValueError(f"Column '{column}' not found in dataset")
+
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
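With the column check moved into `AIIADataLoader.__init__`, a bad column name now fails at construction time instead of surfacing later while items are loaded. A minimal sketch of the resulting behaviour, mirroring the new `test_aiia_data_loader_invalid_column` test further down and assuming the import path those tests use:

import pandas as pd
import pytest
from aiia.data.DataLoader import AIIADataLoader

df = pd.DataFrame({"file_path": ["a.jpg"], "label": [0]})

# construction raises immediately when the configured column is missing
with pytest.raises(ValueError, match="not found in dataset"):
    AIIADataLoader(df, column="nonexistent_column")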
@@ -145,7 +149,6 @@ class AIIADataLoader:
         if not self.items:
             raise ValueError("No valid items were loaded from the dataset")
 
-
        train_indices, val_indices = self._split_data()
 
        self.train_dataset = self._create_subset(train_indices)
@@ -192,9 +195,11 @@ class AIIADataset(torch.utils.data.Dataset):
             if not isinstance(image, Image.Image):
                 raise ValueError(f"Invalid image at index {idx}")
 
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
+
             image = self.transform(image)
-            if image.shape != (3, 224, 224):
-                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
 
             if task == 'denoise':
                 noise_std = 0.1
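Both `__getitem__` branches now validate dimensions on the PIL image before the transform instead of checking the tensor shape afterwards. For reference, `PIL.Image.Image.size` is `(width, height)`, while `torchvision.transforms.ToTensor()` yields a `(C, H, W)` tensor; a short illustrative sketch (not repository code):

from PIL import Image
from torchvision import transforms

img = Image.new('RGB', (224, 224))
print(img.size)        # (224, 224) -> (width, height) on the PIL image

tensor = transforms.ToTensor()(img)
print(tensor.shape)    # torch.Size([3, 224, 224]) -> (C, H, W) after the transform

# the pre-transform check added in this commit rejects anything smaller than 224x224
assert not (img.size[0] < 224 or img.size[1] < 224)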
@@ -214,15 +219,20 @@ class AIIADataset(torch.utils.data.Dataset):
             image, label = item
             if not isinstance(image, Image.Image):
                 raise ValueError(f"Invalid image at index {idx}")
+
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
+
             image = self.transform(image)
-            if image.shape != (3, 224, 224):
-                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
             return image, label
         else:
-            if isinstance(item, Image.Image):
-                image = self.transform(item)
-            else:
-                image = self.transform(item[0])
-            if image.shape != (3, 224, 224):
-                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+            image = item[0] if isinstance(item, tuple) else item
+            if not isinstance(image, Image.Image):
+                raise ValueError(f"Invalid image at index {idx}")
+
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
+            image = self.transform(image)
             return image
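The rewritten `else` branch accepts either a bare PIL image or a tuple whose first element is the image, unwrapping it before running the same size check. A rough illustration of the two accepted `item` shapes (values are hypothetical):

from PIL import Image

img = Image.new('RGB', (224, 224))

for item in (img, (img, 'extra_metadata')):
    # same unwrapping as in the diff: index into item only when it is a tuple
    image = item[0] if isinstance(item, tuple) else item
    assert isinstance(image, Image.Image)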
@@ -0,0 +1,112 @@
+import pytest
+from PIL import Image
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import pandas as pd
+import numpy as np
+from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset
+
+def create_sample_dataset(file_paths=None):
+    if file_paths is None:
+        file_paths = ['path/to/image1.jpg', 'path/to/image2.png']
+
+    data = {
+        'file_path': file_paths,
+        'label': [0] * len(file_paths)  # Match length of labels to file_paths
+    }
+    df = pd.DataFrame(data)
+    return df
+
+def create_sample_bytes_dataset(bytes_data=None):
+    if bytes_data is None:
+        bytes_data = [b'fake_image_data_1', b'fake_image_data_2']
+
+    data = {
+        'jpg': bytes_data,
+        'label': [0] * len(bytes_data)  # Match length of labels to bytes_data
+    }
+    df = pd.DataFrame(data)
+    return df
+
+def test_file_path_loader(mocker):
+    # Mock Image.open to return a fake image
+    mock_image = Image.new('RGB', (224, 224))
+    mocker.patch('PIL.Image.open', return_value=mock_image)
+
+    dataset = create_sample_dataset()
+    loader = FilePathLoader(dataset, label_column='label')  # Added label_column
+    item = loader.get_item(0)
+    assert isinstance(item[0], Image.Image)
+    assert item[1] == 0
+
+    loader.print_summary()
+
+def test_jpg_image_loader(mocker):
+    # Mock Image.open to return a fake image
+    mock_image = Image.new('RGB', (224, 224))
+    mocker.patch('PIL.Image.open', return_value=mock_image)
+
+    dataset = create_sample_bytes_dataset()
+    loader = JPGImageLoader(dataset, label_column='label')  # Added label_column
+    item = loader.get_item(0)
+    assert isinstance(item[0], Image.Image)
+    assert item[1] == 0
+
+    loader.print_summary()
+
+def test_aiia_data_loader(mocker):
+    # Mock Image.open to return a fake image
+    mock_image = Image.new('RGB', (224, 224))
+    mocker.patch('PIL.Image.open', return_value=mock_image)
+
+    dataset = create_sample_dataset()
+    data_loader = AIIADataLoader(dataset, batch_size=2, label_column='label')
+
+    # Test train loader
+    batch = next(iter(data_loader.train_loader))
+    assert isinstance(batch, list)
+    assert len(batch) == 2  # (images, labels)
+    assert batch[0].shape[0] == 1  # batch size
+
+def test_aiia_dataset():
+    items = [(Image.new('RGB', (224, 224)), 0), (Image.new('RGB', (224, 224)), 1)]
+    dataset = AIIADataset(items)
+
+    assert len(dataset) == 2
+
+    item = dataset[0]
+    assert isinstance(item[0], torch.Tensor)
+    assert item[1] == 0
+
+def test_aiia_dataset_pre_training():
+    items = [(Image.new('RGB', (224, 224)), 'denoise', Image.new('RGB', (224, 224)))]
+    dataset = AIIADataset(items, pretraining=True)
+
+    assert len(dataset) == 1
+
+    item = dataset[0]
+    assert isinstance(item[0], torch.Tensor)
+    assert isinstance(item[2], str)
+
+def test_aiia_dataset_invalid_image():
+    items = [(Image.new('RGB', (50, 50)), 0)]  # Create small image
+    dataset = AIIADataset(items)
+
+    with pytest.raises(ValueError, match="Invalid image dimensions"):
+        dataset[0]
+
+def test_aiia_dataset_invalid_task():
+    items = [(Image.new('RGB', (224, 224)), 'invalid_task', Image.new('RGB', (224, 224)))]
+    dataset = AIIADataset(items, pretraining=True)
+
+    with pytest.raises(ValueError):
+        dataset[0]
+
+def test_aiia_data_loader_invalid_column():
+    dataset = create_sample_dataset()
+    with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
+        AIIADataLoader(dataset, column='invalid_column')
+
+if __name__ == "__main__":
+    pytest.main(['-v'])