diff --git a/src/aiia/data/DataLoader.py b/src/aiia/data/DataLoader.py index 4ba5032..6710b61 100644 --- a/src/aiia/data/DataLoader.py +++ b/src/aiia/data/DataLoader.py @@ -14,10 +14,10 @@ class FilePathLoader: self.label_column = label_column self.successful_count = 0 self.skipped_count = 0 - - if self.file_path_column not in dataset.column_names: + + if self.file_path_column not in dataset.columns: raise ValueError(f"Column '{self.file_path_column}' not found in dataset.") - + def _get_image(self, item): try: path = item[self.file_path_column] @@ -32,7 +32,7 @@ class FilePathLoader: except Exception as e: print(f"Error loading image from {path}: {e}") return None - + def get_item(self, idx): item = self.dataset.iloc[idx] image = self._get_image(item) @@ -46,7 +46,7 @@ class FilePathLoader: else: self.skipped_count += 1 return None - + def print_summary(self): print(f"Successfully converted {self.successful_count} images.") print(f"Skipped {self.skipped_count} images due to errors.") @@ -58,14 +58,14 @@ class JPGImageLoader: self.label_column = label_column self.successful_count = 0 self.skipped_count = 0 - + if self.bytes_column not in dataset.columns: raise ValueError(f"Column '{self.bytes_column}' not found in dataset.") def _get_image(self, item): try: data = item[self.bytes_column] - + if isinstance(data, str) and data.startswith("b'"): cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1') bytes_data = cleaned_data @@ -73,7 +73,7 @@ class JPGImageLoader: bytes_data = base64.b64decode(data) else: bytes_data = data - + img_bytes = io.BytesIO(bytes_data) image = Image.open(img_bytes) if image.mode == 'RGBA': @@ -86,7 +86,7 @@ class JPGImageLoader: except Exception as e: print(f"Error loading image from bytes: {e}") return None - + def get_item(self, idx): item = self.dataset.iloc[idx] image = self._get_image(item) @@ -100,37 +100,41 @@ class JPGImageLoader: else: self.skipped_count += 1 return None - + def print_summary(self): print(f"Successfully converted {self.successful_count} images.") print(f"Skipped {self.skipped_count} images due to errors.") class AIIADataLoader: - def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs): + def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", + label_column=None, pretraining=False, **dataloader_kwargs): + if column not in dataset.columns: + raise ValueError(f"Column '{column}' not found in dataset") + self.batch_size = batch_size self.val_split = val_split self.seed = seed self.pretraining = pretraining random.seed(seed) - + sample_value = dataset[column].iloc[0] is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and ( - isinstance(sample_value, bytes) or - sample_value.startswith("b'") or + isinstance(sample_value, bytes) or + sample_value.startswith("b'") or sample_value.startswith(('b"', 'data:image')) ) - + if is_bytes_or_bytestring: self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column) else: sample_paths = dataset[column].dropna().head(1).astype(str) filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$' - + if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths): self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column) else: self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column) - + self.items = [] for idx in range(len(dataset)): item = self.loader.get_item(idx) @@ -141,33 +145,32 @@ class AIIADataLoader: self.items.append((img, 'rotate', 0)) else: self.items.append(item) - + if not self.items: raise ValueError("No valid items were loaded from the dataset") - train_indices, val_indices = self._split_data() - + self.train_dataset = self._create_subset(train_indices) self.val_dataset = self._create_subset(val_indices) - + self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs) self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs) - + def _split_data(self): if len(self.items) == 0: raise ValueError("No items to split") - + num_samples = len(self.items) indices = list(range(num_samples)) random.shuffle(indices) - + split_idx = int((1 - self.val_split) * num_samples) train_indices = indices[:split_idx] val_indices = indices[split_idx:] - + return train_indices, val_indices - + def _create_subset(self, indices): subset_items = [self.items[i] for i in indices] return AIIADataset(subset_items, pretraining=self.pretraining) @@ -180,22 +183,24 @@ class AIIADataset(torch.utils.data.Dataset): transforms.Resize((224, 224)), transforms.ToTensor() ]) - + def __len__(self): return len(self.items) - + def __getitem__(self, idx): item = self.items[idx] - + if self.pretraining: image, task, label = item if not isinstance(image, Image.Image): raise ValueError(f"Invalid image at index {idx}") - + + # Check image dimensions before transform + if image.size[0] < 224 or image.size[1] < 224: + raise ValueError("Invalid image dimensions") + image = self.transform(image) - if image.shape != (3, 224, 224): - raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") - + if task == 'denoise': noise_std = 0.1 noisy_img = image + torch.randn_like(image) * noise_std @@ -214,15 +219,20 @@ class AIIADataset(torch.utils.data.Dataset): image, label = item if not isinstance(image, Image.Image): raise ValueError(f"Invalid image at index {idx}") + + # Check image dimensions before transform + if image.size[0] < 224 or image.size[1] < 224: + raise ValueError("Invalid image dimensions") + image = self.transform(image) - if image.shape != (3, 224, 224): - raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") return image, label else: - if isinstance(item, Image.Image): - image = self.transform(item) - else: - image = self.transform(item[0]) - if image.shape != (3, 224, 224): - raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") + image = item[0] if isinstance(item, tuple) else item + if not isinstance(image, Image.Image): + raise ValueError(f"Invalid image at index {idx}") + + # Check image dimensions before transform + if image.size[0] < 224 or image.size[1] < 224: + raise ValueError("Invalid image dimensions") + image = self.transform(image) return image diff --git a/tests/data/test_DataLoader.py b/tests/data/test_DataLoader.py new file mode 100644 index 0000000..b048d68 --- /dev/null +++ b/tests/data/test_DataLoader.py @@ -0,0 +1,112 @@ +import pytest +from PIL import Image +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +import pandas as pd +import numpy as np +from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset + +def create_sample_dataset(file_paths=None): + if file_paths is None: + file_paths = ['path/to/image1.jpg', 'path/to/image2.png'] + + data = { + 'file_path': file_paths, + 'label': [0] * len(file_paths) # Match length of labels to file_paths + } + df = pd.DataFrame(data) + return df + +def create_sample_bytes_dataset(bytes_data=None): + if bytes_data is None: + bytes_data = [b'fake_image_data_1', b'fake_image_data_2'] + + data = { + 'jpg': bytes_data, + 'label': [0] * len(bytes_data) # Match length of labels to bytes_data + } + df = pd.DataFrame(data) + return df + +def test_file_path_loader(mocker): + # Mock Image.open to return a fake image + mock_image = Image.new('RGB', (224, 224)) + mocker.patch('PIL.Image.open', return_value=mock_image) + + dataset = create_sample_dataset() + loader = FilePathLoader(dataset, label_column='label') # Added label_column + item = loader.get_item(0) + assert isinstance(item[0], Image.Image) + assert item[1] == 0 + + loader.print_summary() + +def test_jpg_image_loader(mocker): + # Mock Image.open to return a fake image + mock_image = Image.new('RGB', (224, 224)) + mocker.patch('PIL.Image.open', return_value=mock_image) + + dataset = create_sample_bytes_dataset() + loader = JPGImageLoader(dataset, label_column='label') # Added label_column + item = loader.get_item(0) + assert isinstance(item[0], Image.Image) + assert item[1] == 0 + + loader.print_summary() + +def test_aiia_data_loader(mocker): + # Mock Image.open to return a fake image + mock_image = Image.new('RGB', (224, 224)) + mocker.patch('PIL.Image.open', return_value=mock_image) + + dataset = create_sample_dataset() + data_loader = AIIADataLoader(dataset, batch_size=2, label_column='label') + + # Test train loader + batch = next(iter(data_loader.train_loader)) + assert isinstance(batch, list) + assert len(batch) == 2 # (images, labels) + assert batch[0].shape[0] == 1 # batch size + +def test_aiia_dataset(): + items = [(Image.new('RGB', (224, 224)), 0), (Image.new('RGB', (224, 224)), 1)] + dataset = AIIADataset(items) + + assert len(dataset) == 2 + + item = dataset[0] + assert isinstance(item[0], torch.Tensor) + assert item[1] == 0 + +def test_aiia_dataset_pre_training(): + items = [(Image.new('RGB', (224, 224)), 'denoise', Image.new('RGB', (224, 224)))] + dataset = AIIADataset(items, pretraining=True) + + assert len(dataset) == 1 + + item = dataset[0] + assert isinstance(item[0], torch.Tensor) + assert isinstance(item[2], str) + +def test_aiia_dataset_invalid_image(): + items = [(Image.new('RGB', (50, 50)), 0)] # Create small image + dataset = AIIADataset(items) + + with pytest.raises(ValueError, match="Invalid image dimensions"): + dataset[0] + +def test_aiia_dataset_invalid_task(): + items = [(Image.new('RGB', (224, 224)), 'invalid_task', Image.new('RGB', (224, 224)))] + dataset = AIIADataset(items, pretraining=True) + + with pytest.raises(ValueError): + dataset[0] + +def test_aiia_data_loader_invalid_column(): + dataset = create_sample_dataset() + with pytest.raises(ValueError, match="Column 'invalid_column' not found"): + AIIADataLoader(dataset, column='invalid_column') + +if __name__ == "__main__": + pytest.main(['-v']) \ No newline at end of file