AIIA/tests/data/test_DataLoader.py

112 lines
3.5 KiB
Python

import pytest
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import pandas as pd
import numpy as np
from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset
def create_sample_dataset(file_paths=None):
    """Return a DataFrame of image file paths, every row labelled 0.

    Args:
        file_paths: Sequence of path strings for the ``file_path`` column.
            When omitted, two placeholder paths (one .jpg, one .png) are used.

    Returns:
        A pandas DataFrame with columns ``file_path`` and ``label``.
    """
    paths = (
        ['path/to/image1.jpg', 'path/to/image2.png']
        if file_paths is None
        else file_paths
    )
    # One zero label per path keeps both columns the same length.
    labels = [0] * len(paths)
    return pd.DataFrame({'file_path': paths, 'label': labels})
def create_sample_bytes_dataset(bytes_data=None):
    """Return a DataFrame of raw image bytes, every row labelled 0.

    Args:
        bytes_data: Sequence of ``bytes`` objects for the ``jpg`` column.
            When omitted, two short placeholder byte strings are used.

    Returns:
        A pandas DataFrame with columns ``jpg`` and ``label``.
    """
    payloads = (
        [b'fake_image_data_1', b'fake_image_data_2']
        if bytes_data is None
        else bytes_data
    )
    # One zero label per payload keeps both columns the same length.
    labels = [0] * len(payloads)
    return pd.DataFrame({'jpg': payloads, 'label': labels})
def test_file_path_loader(mocker):
    """FilePathLoader returns a PIL image and the label for row 0.

    ``PIL.Image.open`` is patched to hand back a blank 224x224 RGB image, so
    the placeholder paths need not exist (assuming the loader opens files via
    ``Image.open``).
    """
    fake = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake)

    loader = FilePathLoader(create_sample_dataset(), label_column='label')
    first = loader.get_item(0)

    assert isinstance(first[0], Image.Image)
    assert first[1] == 0
    # Smoke check only: the summary output itself is not asserted.
    loader.print_summary()
def test_jpg_image_loader(mocker):
    """JPGImageLoader returns a PIL image and the label for row 0.

    ``PIL.Image.open`` is patched to hand back a blank 224x224 RGB image, so
    the fake byte payloads never have to decode as real JPEGs (assuming the
    loader decodes via ``Image.open``).
    """
    fake = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake)

    loader = JPGImageLoader(create_sample_bytes_dataset(), label_column='label')
    first = loader.get_item(0)

    assert isinstance(first[0], Image.Image)
    assert first[1] == 0
    # Smoke check only: the summary output itself is not asserted.
    loader.print_summary()
def test_aiia_data_loader(mocker):
    """AIIADataLoader's train loader yields [images, labels] batches.

    ``PIL.Image.open`` is patched to hand back a blank 224x224 RGB image, so
    the placeholder paths from ``create_sample_dataset`` need not exist.
    """
    mock_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=mock_image)
    dataset = create_sample_dataset()
    data_loader = AIIADataLoader(dataset, batch_size=2, label_column='label')

    # Pull the first batch from the training split.
    batch = next(iter(data_loader.train_loader))
    assert isinstance(batch, list)
    assert len(batch) == 2  # (images, labels)
    # NOTE(review): batch_size=2, yet only one image is expected here.
    # Presumably the loader's train/validation split leaves a single training
    # sample out of the two rows; confirm against AIIADataLoader's split logic.
    assert batch[0].shape[0] == 1
    # Images and labels must stay aligned within a batch.
    assert len(batch[1]) == batch[0].shape[0]
def test_aiia_dataset():
    """AIIADataset returns tensor images and passes integer labels through."""
    samples = [(Image.new('RGB', (224, 224)), label) for label in (0, 1)]
    ds = AIIADataset(samples)

    assert len(ds) == 2

    first = ds[0]
    assert isinstance(first[0], torch.Tensor)
    assert first[1] == 0
def test_aiia_dataset_pre_training():
    """In pretraining mode, an (image, task, image) item yields a tensor first
    and a string third element (presumably the task name)."""
    source = Image.new('RGB', (224, 224))
    target = Image.new('RGB', (224, 224))
    ds = AIIADataset([(source, 'denoise', target)], pretraining=True)

    assert len(ds) == 1

    first = ds[0]
    assert isinstance(first[0], torch.Tensor)
    assert isinstance(first[2], str)
def test_aiia_dataset_invalid_image():
    """Indexing an item whose image is too small (50x50) raises ValueError."""
    tiny = Image.new('RGB', (50, 50))
    ds = AIIADataset([(tiny, 0)])
    with pytest.raises(ValueError, match="Invalid image dimensions"):
        _ = ds[0]
def test_aiia_dataset_invalid_task():
    """In pretraining mode, an unknown task name raises ValueError on access."""
    source = Image.new('RGB', (224, 224))
    target = Image.new('RGB', (224, 224))
    ds = AIIADataset([(source, 'invalid_task', target)], pretraining=True)
    with pytest.raises(ValueError):
        _ = ds[0]
def test_aiia_data_loader_invalid_column():
    """AIIADataLoader rejects a ``column`` that is absent from the DataFrame."""
    frame = create_sample_dataset()
    expected_message = "Column 'invalid_column' not found"
    with pytest.raises(ValueError, match=expected_message):
        AIIADataLoader(frame, column='invalid_column')
if __name__ == "__main__":
    # Run only this module's tests (not whatever pytest would collect from the
    # current working directory), and propagate pytest's exit status so a
    # direct `python` invocation fails when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))