Fixed DataLoader and added test cases.

This commit is contained in:
Falko Victor Habel 2025-03-15 23:42:08 +01:00
parent 24a3a7bf56
commit fd518ff080
2 changed files with 163 additions and 41 deletions

View File

@ -14,10 +14,10 @@ class FilePathLoader:
self.label_column = label_column
self.successful_count = 0
self.skipped_count = 0
if self.file_path_column not in dataset.column_names:
if self.file_path_column not in dataset.columns:
raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
def _get_image(self, item):
try:
path = item[self.file_path_column]
@ -32,7 +32,7 @@ class FilePathLoader:
except Exception as e:
print(f"Error loading image from {path}: {e}")
return None
def get_item(self, idx):
item = self.dataset.iloc[idx]
image = self._get_image(item)
@ -46,7 +46,7 @@ class FilePathLoader:
else:
self.skipped_count += 1
return None
def print_summary(self):
    """Print how many rows were converted into images and how many were skipped."""
    for message in (
        f"Successfully converted {self.successful_count} images.",
        f"Skipped {self.skipped_count} images due to errors.",
    ):
        print(message)
@ -58,14 +58,14 @@ class JPGImageLoader:
self.label_column = label_column
self.successful_count = 0
self.skipped_count = 0
if self.bytes_column not in dataset.columns:
raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
def _get_image(self, item):
try:
data = item[self.bytes_column]
if isinstance(data, str) and data.startswith("b'"):
cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
bytes_data = cleaned_data
@ -73,7 +73,7 @@ class JPGImageLoader:
bytes_data = base64.b64decode(data)
else:
bytes_data = data
img_bytes = io.BytesIO(bytes_data)
image = Image.open(img_bytes)
if image.mode == 'RGBA':
@ -86,7 +86,7 @@ class JPGImageLoader:
except Exception as e:
print(f"Error loading image from bytes: {e}")
return None
def get_item(self, idx):
item = self.dataset.iloc[idx]
image = self._get_image(item)
@ -100,37 +100,41 @@ class JPGImageLoader:
else:
self.skipped_count += 1
return None
def print_summary(self):
    """Report the success/skip counters accumulated while decoding byte rows."""
    summary = (
        f"Successfully converted {self.successful_count} images.",
        f"Skipped {self.skipped_count} images due to errors.",
    )
    for line in summary:
        print(line)
class AIIADataLoader:
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path",
label_column=None, pretraining=False, **dataloader_kwargs):
if column not in dataset.columns:
raise ValueError(f"Column '{column}' not found in dataset")
self.batch_size = batch_size
self.val_split = val_split
self.seed = seed
self.pretraining = pretraining
random.seed(seed)
sample_value = dataset[column].iloc[0]
is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
isinstance(sample_value, bytes) or
sample_value.startswith("b'") or
isinstance(sample_value, bytes) or
sample_value.startswith("b'") or
sample_value.startswith(('b"', 'data:image'))
)
if is_bytes_or_bytestring:
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
else:
sample_paths = dataset[column].dropna().head(1).astype(str)
filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
else:
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
self.items = []
for idx in range(len(dataset)):
item = self.loader.get_item(idx)
@ -141,33 +145,32 @@ class AIIADataLoader:
self.items.append((img, 'rotate', 0))
else:
self.items.append(item)
if not self.items:
raise ValueError("No valid items were loaded from the dataset")
train_indices, val_indices = self._split_data()
self.train_dataset = self._create_subset(train_indices)
self.val_dataset = self._create_subset(val_indices)
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
def _split_data(self):
    """Shuffle item indices and split them into (train, validation) index lists.

    The shuffle uses the module-level `random` state (seeded in __init__),
    and the split point is determined by `self.val_split`.

    Raises:
        ValueError: if no items have been loaded.
    """
    if not self.items:
        raise ValueError("No items to split")
    shuffled = list(range(len(self.items)))
    random.shuffle(shuffled)
    cutoff = int((1 - self.val_split) * len(shuffled))
    return shuffled[:cutoff], shuffled[cutoff:]
def _create_subset(self, indices):
    """Materialize the items at `indices` as an AIIADataset in this loader's mode."""
    selected = [self.items[position] for position in indices]
    return AIIADataset(selected, pretraining=self.pretraining)
@ -180,22 +183,24 @@ class AIIADataset(torch.utils.data.Dataset):
transforms.Resize((224, 224)),
transforms.ToTensor()
])
def __len__(self):
    """Return the number of loaded items in this dataset."""
    return len(self.items)
def __getitem__(self, idx):
item = self.items[idx]
if self.pretraining:
image, task, label = item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
# Check image dimensions before transform
if image.size[0] < 224 or image.size[1] < 224:
raise ValueError("Invalid image dimensions")
image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
if task == 'denoise':
noise_std = 0.1
noisy_img = image + torch.randn_like(image) * noise_std
@ -214,15 +219,20 @@ class AIIADataset(torch.utils.data.Dataset):
image, label = item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
# Check image dimensions before transform
if image.size[0] < 224 or image.size[1] < 224:
raise ValueError("Invalid image dimensions")
image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image, label
else:
if isinstance(item, Image.Image):
image = self.transform(item)
else:
image = self.transform(item[0])
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
image = item[0] if isinstance(item, tuple) else item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
# Check image dimensions before transform
if image.size[0] < 224 or image.size[1] < 224:
raise ValueError("Invalid image dimensions")
image = self.transform(image)
return image

View File

@ -0,0 +1,112 @@
import pytest
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import pandas as pd
import numpy as np
from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset
def create_sample_dataset(file_paths=None):
    """Build a small DataFrame fixture with 'file_path' and zero-valued 'label' columns.

    Args:
        file_paths: optional list of path strings; defaults to two fake image paths.
    """
    paths = ['path/to/image1.jpg', 'path/to/image2.png'] if file_paths is None else file_paths
    return pd.DataFrame({
        'file_path': paths,
        'label': [0 for _ in paths],  # one zero label per path
    })
def create_sample_bytes_dataset(bytes_data=None):
    """Build a small DataFrame fixture with a 'jpg' bytes column and zero labels.

    Args:
        bytes_data: optional list of byte strings; defaults to two fake payloads.
    """
    payloads = [b'fake_image_data_1', b'fake_image_data_2'] if bytes_data is None else bytes_data
    return pd.DataFrame({
        'jpg': payloads,
        'label': [0 for _ in payloads],  # one zero label per payload
    })
def test_file_path_loader(mocker):
    """FilePathLoader should yield a (PIL image, label) tuple for a valid row."""
    # Patch PIL so no real file is read from disk.
    fake_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake_image)
    frame = create_sample_dataset()
    loader = FilePathLoader(frame, label_column='label')
    result = loader.get_item(0)
    assert isinstance(result[0], Image.Image)
    assert result[1] == 0
    loader.print_summary()
def test_jpg_image_loader(mocker):
    """JPGImageLoader should decode a bytes row into a (PIL image, label) tuple."""
    # Patch PIL so the fake byte payloads don't need to be real JPEG data.
    fake_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake_image)
    frame = create_sample_bytes_dataset()
    loader = JPGImageLoader(frame, label_column='label')
    result = loader.get_item(0)
    assert isinstance(result[0], Image.Image)
    assert result[1] == 0
    loader.print_summary()
def test_aiia_data_loader(mocker):
    """AIIADataLoader should expose a working train loader over the fixture frame."""
    fake_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake_image)
    frame = create_sample_dataset()
    loader = AIIADataLoader(frame, batch_size=2, label_column='label')
    first_batch = next(iter(loader.train_loader))
    assert isinstance(first_batch, list)
    assert len(first_batch) == 2  # (images, labels)
    # Two samples with the default 0.2 val split leave one training sample.
    assert first_batch[0].shape[0] == 1
def test_aiia_dataset():
    """In non-pretraining mode the dataset returns (tensor, label) pairs."""
    samples = [
        (Image.new('RGB', (224, 224)), 0),
        (Image.new('RGB', (224, 224)), 1),
    ]
    ds = AIIADataset(samples)
    assert len(ds) == 2
    sample = ds[0]
    assert isinstance(sample[0], torch.Tensor)
    assert sample[1] == 0
def test_aiia_dataset_pre_training():
    """In pretraining mode the dataset should include the task name in each item."""
    samples = [(Image.new('RGB', (224, 224)), 'denoise', Image.new('RGB', (224, 224)))]
    ds = AIIADataset(samples, pretraining=True)
    assert len(ds) == 1
    sample = ds[0]
    assert isinstance(sample[0], torch.Tensor)
    assert isinstance(sample[2], str)
def test_aiia_dataset_invalid_image():
    """Images smaller than 224x224 must be rejected with a ValueError."""
    undersized = [(Image.new('RGB', (50, 50)), 0)]
    ds = AIIADataset(undersized)
    with pytest.raises(ValueError, match="Invalid image dimensions"):
        ds[0]
def test_aiia_dataset_invalid_task():
    """An unrecognized pretraining task name should raise a ValueError."""
    samples = [(Image.new('RGB', (224, 224)), 'invalid_task', Image.new('RGB', (224, 224)))]
    ds = AIIADataset(samples, pretraining=True)
    with pytest.raises(ValueError):
        ds[0]
def test_aiia_data_loader_invalid_column():
    """Constructing the loader with a missing column name must fail fast."""
    frame = create_sample_dataset()
    with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
        AIIADataLoader(frame, column='invalid_column')
if __name__ == "__main__":
    # Allow running this test module directly; delegates to pytest in verbose mode.
    pytest.main(['-v'])