112 lines
3.5 KiB
Python
112 lines
3.5 KiB
Python
import pytest
|
|
from PIL import Image
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from torchvision import transforms
|
|
import pandas as pd
|
|
import numpy as np
|
|
from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset
|
|
|
|
def create_sample_dataset(file_paths=None):
    """Build a small path-based DataFrame for the loader tests.

    Args:
        file_paths: Values for the ``file_path`` column. When ``None``, two
            placeholder paths (one ``.jpg`` and one ``.png``) are used.

    Returns:
        A ``pandas.DataFrame`` with a ``file_path`` column and a ``label``
        column that holds ``0`` for every row.
    """
    paths = (
        ['path/to/image1.jpg', 'path/to/image2.png']
        if file_paths is None
        else file_paths
    )
    # Every row gets label 0; the label list must be as long as the paths.
    return pd.DataFrame(dict(file_path=paths, label=[0] * len(paths)))
|
|
|
|
def create_sample_bytes_dataset(bytes_data=None):
    """Build a small byte-payload DataFrame for the loader tests.

    Args:
        bytes_data: Values for the ``jpg`` column. When ``None``, two short
            placeholder byte strings are used. They are not real image data.

    Returns:
        A ``pandas.DataFrame`` with a ``jpg`` column and a ``label`` column
        that holds ``0`` for every row.
    """
    payloads = (
        [b'fake_image_data_1', b'fake_image_data_2']
        if bytes_data is None
        else bytes_data
    )
    # Every row gets label 0; the label list must be as long as the payloads.
    return pd.DataFrame(dict(jpg=payloads, label=[0] * len(payloads)))
|
|
|
|
def test_file_path_loader(mocker):
    """FilePathLoader returns a PIL image and its label for row 0.

    ``PIL.Image.open`` is patched to return a blank 224x224 RGB image, so the
    placeholder paths in the sample frame are never read from disk. This
    presumes the loader opens files through ``PIL.Image.open``.
    """
    blank = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=blank)

    loader = FilePathLoader(create_sample_dataset(), label_column='label')
    sample = loader.get_item(0)
    assert isinstance(sample[0], Image.Image)
    assert sample[1] == 0

    # Smoke check only: printing the summary must not raise.
    loader.print_summary()
|
|
|
|
def test_jpg_image_loader(mocker):
    """JPGImageLoader returns a PIL image and its label for row 0.

    The sample byte payloads are not valid image data. The test relies on
    ``PIL.Image.open`` being patched to return a blank 224x224 RGB image,
    which presumes the loader decodes bytes through ``PIL.Image.open``.
    """
    blank = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=blank)

    loader = JPGImageLoader(create_sample_bytes_dataset(), label_column='label')
    sample = loader.get_item(0)
    assert isinstance(sample[0], Image.Image)
    assert sample[1] == 0

    # Smoke check only: printing the summary must not raise.
    loader.print_summary()
|
|
|
|
def test_aiia_data_loader(mocker):
    """The train loader of AIIADataLoader yields ``[images, labels]`` batches."""
    blank = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=blank)

    data_loader = AIIADataLoader(
        create_sample_dataset(), batch_size=2, label_column='label'
    )

    first_batch = next(iter(data_loader.train_loader))
    assert isinstance(first_batch, list)
    assert len(first_batch) == 2  # (images, labels)
    # NOTE(review): batch_size is 2, but only one sample is expected in the
    # first train batch. Presumably the 2-row sample frame is split into
    # train/validation and the train split keeps a single row. Confirm this
    # against AIIADataLoader's split logic.
    assert first_batch[0].shape[0] == 1
|
|
|
|
def test_aiia_dataset():
    """AIIADataset exposes its length and yields (tensor, label) samples."""
    samples = [(Image.new('RGB', (224, 224)), label) for label in (0, 1)]
    dataset = AIIADataset(samples)

    assert len(dataset) == 2

    first = dataset[0]
    # The PIL image is expected to come back converted to a tensor.
    assert isinstance(first[0], torch.Tensor)
    assert first[1] == 0
|
|
|
|
def test_aiia_dataset_pre_training():
    """In pretraining mode a ``(image, task, image)`` sample is accepted.

    The returned sample must carry a tensor at index 0 and a string at
    index 2. The string is presumably the task name ('denoise'); confirm
    this against AIIADataset.
    """
    source = Image.new('RGB', (224, 224))
    target = Image.new('RGB', (224, 224))
    dataset = AIIADataset([(source, 'denoise', target)], pretraining=True)

    assert len(dataset) == 1

    first = dataset[0]
    assert isinstance(first[0], torch.Tensor)
    assert isinstance(first[2], str)
|
|
|
|
def test_aiia_dataset_invalid_image():
    """Indexing a sample whose image is too small raises ValueError."""
    # 50x50 is far below the 224x224 images used in the other tests.
    tiny = Image.new('RGB', (50, 50))
    dataset = AIIADataset([(tiny, 0)])

    with pytest.raises(ValueError, match="Invalid image dimensions"):
        dataset[0]
|
|
|
|
def test_aiia_dataset_invalid_task():
    """In pretraining mode an unknown task name makes indexing raise ValueError."""
    source = Image.new('RGB', (224, 224))
    target = Image.new('RGB', (224, 224))
    dataset = AIIADataset([(source, 'invalid_task', target)], pretraining=True)

    with pytest.raises(ValueError):
        dataset[0]
|
|
|
|
def test_aiia_data_loader_invalid_column():
    """Constructing AIIADataLoader with an absent column raises ValueError."""
    frame = create_sample_dataset()
    # NOTE(review): this passes ``column=`` while the other tests pass
    # ``label_column=``. Confirm AIIADataLoader accepts ``column``; otherwise
    # a TypeError would be raised instead of the expected ValueError.
    with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
        AIIADataLoader(frame, column='invalid_column')
|
|
|
|
if __name__ == "__main__":
    # Run only this module's tests. A bare pytest.main(['-v']) would collect
    # whatever lives under the current working directory. Also propagate
    # pytest's exit status, so running the file directly exits non-zero when
    # any test fails.
    raise SystemExit(pytest.main(['-v', __file__]))