Fixed DataLoader and added test cases.
commit fd518ff080 (parent 24a3a7bf56)
@@ -14,10 +14,10 @@ class FilePathLoader:
         self.label_column = label_column
         self.successful_count = 0
         self.skipped_count = 0
 
-        if self.file_path_column not in dataset.column_names:
+        if self.file_path_column not in dataset.columns:
             raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
 
     def _get_image(self, item):
         try:
             path = item[self.file_path_column]
@@ -32,7 +32,7 @@ class FilePathLoader:
         except Exception as e:
             print(f"Error loading image from {path}: {e}")
             return None
 
     def get_item(self, idx):
         item = self.dataset.iloc[idx]
         image = self._get_image(item)
@@ -46,7 +46,7 @@ class FilePathLoader:
         else:
             self.skipped_count += 1
             return None
 
     def print_summary(self):
         print(f"Successfully converted {self.successful_count} images.")
         print(f"Skipped {self.skipped_count} images due to errors.")
@@ -58,14 +58,14 @@ class JPGImageLoader:
         self.label_column = label_column
         self.successful_count = 0
         self.skipped_count = 0
 
         if self.bytes_column not in dataset.columns:
             raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
 
     def _get_image(self, item):
         try:
             data = item[self.bytes_column]
 
             if isinstance(data, str) and data.startswith("b'"):
                 cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
                 bytes_data = cleaned_data
@@ -73,7 +73,7 @@ class JPGImageLoader:
                 bytes_data = base64.b64decode(data)
             else:
                 bytes_data = data
 
             img_bytes = io.BytesIO(bytes_data)
             image = Image.open(img_bytes)
             if image.mode == 'RGBA':
@@ -86,7 +86,7 @@ class JPGImageLoader:
         except Exception as e:
             print(f"Error loading image from bytes: {e}")
             return None
 
     def get_item(self, idx):
         item = self.dataset.iloc[idx]
         image = self._get_image(item)
@@ -100,37 +100,41 @@ class JPGImageLoader:
         else:
             self.skipped_count += 1
             return None
 
     def print_summary(self):
         print(f"Successfully converted {self.successful_count} images.")
         print(f"Skipped {self.skipped_count} images due to errors.")
 
 class AIIADataLoader:
-    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path",
+                 label_column=None, pretraining=False, **dataloader_kwargs):
+        if column not in dataset.columns:
+            raise ValueError(f"Column '{column}' not found in dataset")
+
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
         self.pretraining = pretraining
         random.seed(seed)
 
         sample_value = dataset[column].iloc[0]
         is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
             isinstance(sample_value, bytes) or
             sample_value.startswith("b'") or
             sample_value.startswith(('b"', 'data:image'))
         )
 
         if is_bytes_or_bytestring:
             self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
         else:
             sample_paths = dataset[column].dropna().head(1).astype(str)
             filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
 
             if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
                 self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
             else:
                 self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
 
         self.items = []
         for idx in range(len(dataset)):
             item = self.loader.get_item(idx)
@@ -141,33 +145,32 @@ class AIIADataLoader:
                     self.items.append((img, 'rotate', 0))
             else:
                 self.items.append(item)
 
         if not self.items:
             raise ValueError("No valid items were loaded from the dataset")
 
-
         train_indices, val_indices = self._split_data()
 
         self.train_dataset = self._create_subset(train_indices)
         self.val_dataset = self._create_subset(val_indices)
 
         self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
         self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
 
     def _split_data(self):
         if len(self.items) == 0:
             raise ValueError("No items to split")
 
         num_samples = len(self.items)
         indices = list(range(num_samples))
         random.shuffle(indices)
 
         split_idx = int((1 - self.val_split) * num_samples)
         train_indices = indices[:split_idx]
         val_indices = indices[split_idx:]
 
         return train_indices, val_indices
 
     def _create_subset(self, indices):
         subset_items = [self.items[i] for i in indices]
         return AIIADataset(subset_items, pretraining=self.pretraining)
@@ -180,22 +183,24 @@ class AIIADataset(torch.utils.data.Dataset):
             transforms.Resize((224, 224)),
             transforms.ToTensor()
         ])
 
     def __len__(self):
         return len(self.items)
 
     def __getitem__(self, idx):
         item = self.items[idx]
 
         if self.pretraining:
             image, task, label = item
             if not isinstance(image, Image.Image):
                 raise ValueError(f"Invalid image at index {idx}")
 
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
+
             image = self.transform(image)
-            if image.shape != (3, 224, 224):
-                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
 
             if task == 'denoise':
                 noise_std = 0.1
                 noisy_img = image + torch.randn_like(image) * noise_std
@@ -214,15 +219,20 @@ class AIIADataset(torch.utils.data.Dataset):
             image, label = item
             if not isinstance(image, Image.Image):
                 raise ValueError(f"Invalid image at index {idx}")
+
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
+
             image = self.transform(image)
-            if image.shape != (3, 224, 224):
-                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
             return image, label
         else:
-            if isinstance(item, Image.Image):
-                image = self.transform(item)
-            else:
-                image = self.transform(item[0])
-            if image.shape != (3, 224, 224):
-                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+            image = item[0] if isinstance(item, tuple) else item
+            if not isinstance(image, Image.Image):
+                raise ValueError(f"Invalid image at index {idx}")
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
+            image = self.transform(image)
             return image
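
For orientation, a minimal usage sketch of the loader after this change. It assumes the classes are importable from aiia.data.DataLoader (the path used by the new tests below); the DataFrame contents, file names, and labels are made up for illustration and are not part of the commit.

import pandas as pd
from aiia.data.DataLoader import AIIADataLoader

# Hypothetical DataFrame; any column holding file paths (or raw image bytes) works,
# as long as `column` names it. Paths must exist on disk and images should be
# at least 224x224, or the items are skipped / rejected downstream.
df = pd.DataFrame({
    "file_path": ["images/cat_001.jpg", "images/dog_001.jpg"],
    "label": [0, 1],
})

# The check added to __init__ now fails fast on a bad column name:
try:
    AIIADataLoader(df, column="not_a_column")
except ValueError as err:
    print(err)  # Column 'not_a_column' not found in dataset

loader = AIIADataLoader(df, batch_size=2, column="file_path", label_column="label")
for images, labels in loader.train_loader:
    pass  # training step goes here

The same constructor also accepts a column of raw image bytes and routes it to JPGImageLoader automatically, as the sample-value sniffing above shows.
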
@@ -0,0 +1,112 @@
+import pytest
+from PIL import Image
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import pandas as pd
+import numpy as np
+from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset
+
+def create_sample_dataset(file_paths=None):
+    if file_paths is None:
+        file_paths = ['path/to/image1.jpg', 'path/to/image2.png']
+
+    data = {
+        'file_path': file_paths,
+        'label': [0] * len(file_paths)  # Match length of labels to file_paths
+    }
+    df = pd.DataFrame(data)
+    return df
+
+def create_sample_bytes_dataset(bytes_data=None):
+    if bytes_data is None:
+        bytes_data = [b'fake_image_data_1', b'fake_image_data_2']
+
+    data = {
+        'jpg': bytes_data,
+        'label': [0] * len(bytes_data)  # Match length of labels to bytes_data
+    }
+    df = pd.DataFrame(data)
+    return df
+
+def test_file_path_loader(mocker):
+    # Mock Image.open to return a fake image
+    mock_image = Image.new('RGB', (224, 224))
+    mocker.patch('PIL.Image.open', return_value=mock_image)
+
+    dataset = create_sample_dataset()
+    loader = FilePathLoader(dataset, label_column='label')  # Added label_column
+    item = loader.get_item(0)
+    assert isinstance(item[0], Image.Image)
+    assert item[1] == 0
+
+    loader.print_summary()
+
+def test_jpg_image_loader(mocker):
+    # Mock Image.open to return a fake image
+    mock_image = Image.new('RGB', (224, 224))
+    mocker.patch('PIL.Image.open', return_value=mock_image)
+
+    dataset = create_sample_bytes_dataset()
+    loader = JPGImageLoader(dataset, label_column='label')  # Added label_column
+    item = loader.get_item(0)
+    assert isinstance(item[0], Image.Image)
+    assert item[1] == 0
+
+    loader.print_summary()
+
+def test_aiia_data_loader(mocker):
+    # Mock Image.open to return a fake image
+    mock_image = Image.new('RGB', (224, 224))
+    mocker.patch('PIL.Image.open', return_value=mock_image)
+
+    dataset = create_sample_dataset()
+    data_loader = AIIADataLoader(dataset, batch_size=2, label_column='label')
+
+    # Test train loader
+    batch = next(iter(data_loader.train_loader))
+    assert isinstance(batch, list)
+    assert len(batch) == 2  # (images, labels)
+    assert batch[0].shape[0] == 1  # batch size
+
+def test_aiia_dataset():
+    items = [(Image.new('RGB', (224, 224)), 0), (Image.new('RGB', (224, 224)), 1)]
+    dataset = AIIADataset(items)
+
+    assert len(dataset) == 2
+
+    item = dataset[0]
+    assert isinstance(item[0], torch.Tensor)
+    assert item[1] == 0
+
+def test_aiia_dataset_pre_training():
+    items = [(Image.new('RGB', (224, 224)), 'denoise', Image.new('RGB', (224, 224)))]
+    dataset = AIIADataset(items, pretraining=True)
+
+    assert len(dataset) == 1
+
+    item = dataset[0]
+    assert isinstance(item[0], torch.Tensor)
+    assert isinstance(item[2], str)
+
+def test_aiia_dataset_invalid_image():
+    items = [(Image.new('RGB', (50, 50)), 0)]  # Create small image
+    dataset = AIIADataset(items)
+
+    with pytest.raises(ValueError, match="Invalid image dimensions"):
+        dataset[0]
+
+def test_aiia_dataset_invalid_task():
+    items = [(Image.new('RGB', (224, 224)), 'invalid_task', Image.new('RGB', (224, 224)))]
+    dataset = AIIADataset(items, pretraining=True)
+
+    with pytest.raises(ValueError):
+        dataset[0]
+
+def test_aiia_data_loader_invalid_column():
+    dataset = create_sample_dataset()
+    with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
+        AIIADataLoader(dataset, column='invalid_column')
+
+if __name__ == "__main__":
+    pytest.main(['-v'])
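
A note on running the new tests: the `mocker` fixture they use is provided by the pytest-mock plugin, so pytest-mock must be installed alongside pytest. The module's `__main__` guard already calls `pytest.main(['-v'])`; an equivalent standalone runner would be a sketch like the one below (the test file's location is not shown in this diff, so any path passed to it is an assumption).

# run_tests.py -- hypothetical helper; equivalent to `pytest -v` from the repo root
import sys
import pytest

# Pass an explicit file or directory argument if the tests live elsewhere.
sys.exit(pytest.main(["-v"]))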