Fixed the DataLoader and added test cases.
This commit is contained in:
parent
24a3a7bf56
commit
fd518ff080
|
@ -14,10 +14,10 @@ class FilePathLoader:
|
|||
self.label_column = label_column
|
||||
self.successful_count = 0
|
||||
self.skipped_count = 0
|
||||
|
||||
if self.file_path_column not in dataset.column_names:
|
||||
|
||||
if self.file_path_column not in dataset.columns:
|
||||
raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
|
||||
|
||||
|
||||
def _get_image(self, item):
|
||||
try:
|
||||
path = item[self.file_path_column]
|
||||
|
@ -32,7 +32,7 @@ class FilePathLoader:
|
|||
except Exception as e:
|
||||
print(f"Error loading image from {path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_item(self, idx):
|
||||
item = self.dataset.iloc[idx]
|
||||
image = self._get_image(item)
|
||||
|
@ -46,7 +46,7 @@ class FilePathLoader:
|
|||
else:
|
||||
self.skipped_count += 1
|
||||
return None
|
||||
|
||||
|
||||
def print_summary(self):
    """Print how many images were converted successfully and how many were skipped."""
    report = [
        f"Successfully converted {self.successful_count} images.",
        f"Skipped {self.skipped_count} images due to errors.",
    ]
    for line in report:
        print(line)
|
||||
|
@ -58,14 +58,14 @@ class JPGImageLoader:
|
|||
self.label_column = label_column
|
||||
self.successful_count = 0
|
||||
self.skipped_count = 0
|
||||
|
||||
|
||||
if self.bytes_column not in dataset.columns:
|
||||
raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
|
||||
|
||||
def _get_image(self, item):
|
||||
try:
|
||||
data = item[self.bytes_column]
|
||||
|
||||
|
||||
if isinstance(data, str) and data.startswith("b'"):
|
||||
cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
|
||||
bytes_data = cleaned_data
|
||||
|
@ -73,7 +73,7 @@ class JPGImageLoader:
|
|||
bytes_data = base64.b64decode(data)
|
||||
else:
|
||||
bytes_data = data
|
||||
|
||||
|
||||
img_bytes = io.BytesIO(bytes_data)
|
||||
image = Image.open(img_bytes)
|
||||
if image.mode == 'RGBA':
|
||||
|
@ -86,7 +86,7 @@ class JPGImageLoader:
|
|||
except Exception as e:
|
||||
print(f"Error loading image from bytes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_item(self, idx):
|
||||
item = self.dataset.iloc[idx]
|
||||
image = self._get_image(item)
|
||||
|
@ -100,37 +100,41 @@ class JPGImageLoader:
|
|||
else:
|
||||
self.skipped_count += 1
|
||||
return None
|
||||
|
||||
|
||||
def print_summary(self):
    """Print a two-line summary of conversion successes and skips."""
    for message in (
        f"Successfully converted {self.successful_count} images.",
        f"Skipped {self.skipped_count} images due to errors.",
    ):
        print(message)
|
||||
|
||||
class AIIADataLoader:
|
||||
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
|
||||
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path",
|
||||
label_column=None, pretraining=False, **dataloader_kwargs):
|
||||
if column not in dataset.columns:
|
||||
raise ValueError(f"Column '{column}' not found in dataset")
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.val_split = val_split
|
||||
self.seed = seed
|
||||
self.pretraining = pretraining
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
sample_value = dataset[column].iloc[0]
|
||||
is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
|
||||
isinstance(sample_value, bytes) or
|
||||
sample_value.startswith("b'") or
|
||||
isinstance(sample_value, bytes) or
|
||||
sample_value.startswith("b'") or
|
||||
sample_value.startswith(('b"', 'data:image'))
|
||||
)
|
||||
|
||||
|
||||
if is_bytes_or_bytestring:
|
||||
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
|
||||
else:
|
||||
sample_paths = dataset[column].dropna().head(1).astype(str)
|
||||
filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
|
||||
|
||||
|
||||
if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
|
||||
self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
|
||||
else:
|
||||
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
|
||||
|
||||
|
||||
self.items = []
|
||||
for idx in range(len(dataset)):
|
||||
item = self.loader.get_item(idx)
|
||||
|
@ -141,33 +145,32 @@ class AIIADataLoader:
|
|||
self.items.append((img, 'rotate', 0))
|
||||
else:
|
||||
self.items.append(item)
|
||||
|
||||
|
||||
if not self.items:
|
||||
raise ValueError("No valid items were loaded from the dataset")
|
||||
|
||||
|
||||
train_indices, val_indices = self._split_data()
|
||||
|
||||
|
||||
self.train_dataset = self._create_subset(train_indices)
|
||||
self.val_dataset = self._create_subset(val_indices)
|
||||
|
||||
|
||||
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
|
||||
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
|
||||
|
||||
|
||||
def _split_data(self):
|
||||
if len(self.items) == 0:
|
||||
raise ValueError("No items to split")
|
||||
|
||||
|
||||
num_samples = len(self.items)
|
||||
indices = list(range(num_samples))
|
||||
random.shuffle(indices)
|
||||
|
||||
|
||||
split_idx = int((1 - self.val_split) * num_samples)
|
||||
train_indices = indices[:split_idx]
|
||||
val_indices = indices[split_idx:]
|
||||
|
||||
|
||||
return train_indices, val_indices
|
||||
|
||||
|
||||
def _create_subset(self, indices):
    """Build an AIIADataset over the items selected by ``indices``."""
    selected = [self.items[i] for i in indices]
    return AIIADataset(selected, pretraining=self.pretraining)
|
||||
|
@ -180,22 +183,24 @@ class AIIADataset(torch.utils.data.Dataset):
|
|||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.items[idx]
|
||||
|
||||
|
||||
if self.pretraining:
|
||||
image, task, label = item
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError(f"Invalid image at index {idx}")
|
||||
|
||||
|
||||
# Check image dimensions before transform
|
||||
if image.size[0] < 224 or image.size[1] < 224:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
|
||||
image = self.transform(image)
|
||||
if image.shape != (3, 224, 224):
|
||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||
|
||||
|
||||
if task == 'denoise':
|
||||
noise_std = 0.1
|
||||
noisy_img = image + torch.randn_like(image) * noise_std
|
||||
|
@ -214,15 +219,20 @@ class AIIADataset(torch.utils.data.Dataset):
|
|||
image, label = item
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError(f"Invalid image at index {idx}")
|
||||
|
||||
# Check image dimensions before transform
|
||||
if image.size[0] < 224 or image.size[1] < 224:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
|
||||
image = self.transform(image)
|
||||
if image.shape != (3, 224, 224):
|
||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||
return image, label
|
||||
else:
|
||||
if isinstance(item, Image.Image):
|
||||
image = self.transform(item)
|
||||
else:
|
||||
image = self.transform(item[0])
|
||||
if image.shape != (3, 224, 224):
|
||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||
image = item[0] if isinstance(item, tuple) else item
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError(f"Invalid image at index {idx}")
|
||||
|
||||
# Check image dimensions before transform
|
||||
if image.size[0] < 224 or image.size[1] < 224:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
image = self.transform(image)
|
||||
return image
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
import pytest
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset
|
||||
|
||||
def create_sample_dataset(file_paths=None):
    """Return a small DataFrame with 'file_path' and 'label' columns for tests.

    Args:
        file_paths: optional list of path strings; defaults to two fake paths.
    """
    paths = ['path/to/image1.jpg', 'path/to/image2.png'] if file_paths is None else file_paths
    # Every row gets label 0 so the label column always matches the path count.
    return pd.DataFrame({'file_path': paths, 'label': [0] * len(paths)})
|
||||
|
||||
def create_sample_bytes_dataset(bytes_data=None):
    """Return a small DataFrame with a 'jpg' bytes column and matching labels.

    Args:
        bytes_data: optional list of raw byte strings; defaults to two fakes.
    """
    payloads = [b'fake_image_data_1', b'fake_image_data_2'] if bytes_data is None else bytes_data
    # Every row gets label 0 so the label column always matches the payload count.
    return pd.DataFrame({'jpg': payloads, 'label': [0] * len(payloads)})
|
||||
|
||||
def test_file_path_loader(mocker):
    """FilePathLoader yields (image, label) items for path-based rows."""
    # Patch PIL so no file actually has to exist on disk.
    fake_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake_image)

    df = create_sample_dataset()
    loader = FilePathLoader(df, label_column='label')

    first = loader.get_item(0)
    assert isinstance(first[0], Image.Image)
    assert first[1] == 0

    # Summary printing must not raise.
    loader.print_summary()
|
||||
|
||||
def test_jpg_image_loader(mocker):
    """JPGImageLoader yields (image, label) items for byte-encoded rows."""
    # Patch PIL so the fake byte payloads decode to a real image.
    fake_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake_image)

    df = create_sample_bytes_dataset()
    loader = JPGImageLoader(df, label_column='label')

    first = loader.get_item(0)
    assert isinstance(first[0], Image.Image)
    assert first[1] == 0

    # Summary printing must not raise.
    loader.print_summary()
|
||||
|
||||
def test_aiia_data_loader(mocker):
    """AIIADataLoader exposes a working train DataLoader over path data."""
    # Patch PIL so no file actually has to exist on disk.
    fake_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake_image)

    loader = AIIADataLoader(create_sample_dataset(), batch_size=2, label_column='label')

    batch = next(iter(loader.train_loader))
    assert isinstance(batch, list)
    assert len(batch) == 2  # (images, labels)
    # Two rows with val_split=0.2 leave a single item in the train split.
    assert batch[0].shape[0] == 1
|
||||
|
||||
def test_aiia_dataset():
    """Default-mode AIIADataset returns (tensor, label) pairs."""
    samples = [
        (Image.new('RGB', (224, 224)), 0),
        (Image.new('RGB', (224, 224)), 1),
    ]
    dataset = AIIADataset(samples)

    assert len(dataset) == 2

    first = dataset[0]
    assert isinstance(first[0], torch.Tensor)
    assert first[1] == 0
|
||||
|
||||
def test_aiia_dataset_pre_training():
    """Pretraining-mode items are triples whose third element is the task string."""
    source = Image.new('RGB', (224, 224))
    target = Image.new('RGB', (224, 224))
    dataset = AIIADataset([(source, 'denoise', target)], pretraining=True)

    assert len(dataset) == 1

    sample = dataset[0]
    assert isinstance(sample[0], torch.Tensor)
    assert isinstance(sample[2], str)
|
||||
|
||||
def test_aiia_dataset_invalid_image():
    """Images smaller than 224x224 must be rejected at access time."""
    undersized = Image.new('RGB', (50, 50))
    dataset = AIIADataset([(undersized, 0)])

    with pytest.raises(ValueError, match="Invalid image dimensions"):
        _ = dataset[0]
|
||||
|
||||
def test_aiia_dataset_invalid_task():
    """An unknown pretraining task name must raise ValueError on access."""
    img = Image.new('RGB', (224, 224))
    dataset = AIIADataset([(img, 'invalid_task', Image.new('RGB', (224, 224)))], pretraining=True)

    with pytest.raises(ValueError):
        _ = dataset[0]
|
||||
|
||||
def test_aiia_data_loader_invalid_column():
    """Constructing with a missing column name raises a descriptive error."""
    df = create_sample_dataset()
    with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
        AIIADataLoader(df, column='invalid_column')
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (python test_file.py) instead
    # of only via the pytest CLI; -v gives per-test verbose output.
    pytest.main(['-v'])
|
Loading…
Reference in New Issue