updated DataLoader for new AIIA CNN arch

This commit is contained in:
Falko Victor Habel 2025-01-21 21:19:37 +01:00
parent 99c3ec38c7
commit 106539f48a
1 changed file with 194 additions and 85 deletions

View File

@ -1,97 +1,206 @@
import os
import io
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import random
import numpy as np
from sklearn.model_selection import train_test_split
import re
# NOTE(review): this span is mangled diff residue — lines from the OLD
# AIIADataLoader.__init__ (deleted side of the diff) are interleaved with the
# NEW FilePathLoader.__init__ (added side). It is not runnable as written.
class AIIADataLoader:
# OLD version: walked a directory of images and built denoise/rotate splits.
def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42, limit=None):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# NEW version begins here: FilePathLoader resolves images via a file-path column.
class FilePathLoader:
def __init__(self, dataset, file_path_column="file_path", label_column=None):
self.dataset = dataset
self.file_path_column = file_path_column
self.label_column = label_column
self.successful_count = 0
self.skipped_count = 0
# The lines below resume the OLD AIIADataLoader.__init__.
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {self.device}')
image_paths = self._load_image_paths(data_dir, limit=limit)
train_paths, val_paths = self._split_data(image_paths, val_split=val_split)
# Create combined dataset for training with both denoise and rotate tasks
train_denoise_paths = [(path, 'denoise') for path in train_paths]
train_rotate_paths = [(path, 'rotate') for path in train_paths]
train_combined = train_denoise_paths + train_rotate_paths
val_denoise_paths = [(path, 'denoise') for path in val_paths]
val_rotate_paths = [(path, 'rotate') for path in val_paths]
val_combined = val_denoise_paths + val_rotate_paths
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor()
])
self.train_dataset = AIIADataset(train_combined, transform=transform)
self.val_dataset = AIIADataset(val_combined, transform=transform)
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)
def _load_image_paths(self, data_dir, limit=None):
extensions = ('.png', '.jpeg', '.jpg')
image_paths = []
for root, dirs, files in os.walk(data_dir):
for file in files:
if file.lower().endswith(extensions):
image_paths.append(os.path.join(root, file))
image_paths = sorted(list(set(image_paths)))
if limit is not None:
image_paths = image_paths[:limit]
return image_paths
def _split_data(self, image_paths, val_split=0.2):
"""Split image paths into train/validation lists via sklearn's train_test_split."""
# NOTE(review): random_state is hard-coded to 42 and ignores the seed
# argument accepted by __init__ — confirm that is intended.
train_paths, val_paths = train_test_split(
image_paths, test_size=val_split, random_state=42
)
return train_paths, val_paths
# NOTE(review): diff residue — these validation lines belong at the end of
# FilePathLoader.__init__ above, not at this position in the interleave.
if self.file_path_column not in dataset.column_names:
raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
def _get_image(self, item):
try:
path = item[self.file_path_column]
image = Image.open(path).convert("RGB")
return image
except Exception as e:
print(f"Error loading image from {path}: {e}")
return None
# NOTE(review): OLD AIIADataset (deleted side of the diff). Its __getitem__
# body is cut off below and continues in interleaved fragments further down
# the file; treat this whole span as non-runnable residue.
class AIIADataset(torch.utils.data.Dataset):
def __init__(self, data_paths, transform=None, preload=False):
# data_paths: list of (path, task) pairs; preload eagerly caches decoded images.
self.data_paths = data_paths
self.transform = transform
self.preload = preload
self.loaded_images = {}
if self.preload:
for path, task in self.data_paths:
img = Image.open(path).convert('RGB')
self.loaded_images[path] = img
def __len__(self):
return len(self.data_paths)
def __getitem__(self, idx):
path, task = self.data_paths[idx]
if self.preload:
img = self.loaded_images[path]
def get_item(self, idx):
    """Fetch row *idx* and decode its image.

    Returns (image,) or (image, label) on success. On load failure the
    skip counter is incremented and None is returned.
    """
    item = self.dataset[idx]
    image = self._get_image(item)
    if image is None:
        self.skipped_count += 1
        return None
    self.successful_count += 1
    if self.label_column is not None:
        label = item.get(self.label_column)
        return (image, label)
    return (image,)
def print_summary(self):
print(f"Successfully converted {self.successful_count} images.")
print(f"Skipped {self.skipped_count} images due to errors.")
class JPGImageLoader:
"""Decodes images stored as raw bytes in a dataset column."""
def __init__(self, dataset, bytes_column="jpg", label_column=None):
self.dataset = dataset
self.bytes_column = bytes_column
self.label_column = label_column
self.successful_count = 0
self.skipped_count = 0
# NOTE(review): the task-handling lines below are diff residue from the OLD
# AIIADataset.__getitem__ interleaved into this class body; not runnable here.
if task == 'denoise':
noise_std = 0.1
noisy_img = img + torch.randn_like(img) * noise_std
target = img
return noisy_img, target, task
elif task == 'rotate':
angles = [0, 90, 180, 270]
angle = random.choice(angles)
rotated_img = transforms.functional.rotate(img, angle)
target = torch.tensor(angle).long()
return rotated_img, target, task
# __init__ validation resumes here.
if self.bytes_column not in dataset.column_names:
raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
def _get_image(self, item):
try:
bytes_data = item[self.bytes_column]
img_bytes = io.BytesIO(bytes_data)
image = Image.open(img_bytes).convert("RGB")
return image
except Exception as e:
print(f"Error loading image from bytes: {e}")
return None
def get_item(self, idx):
    """Fetch row *idx*, decode its image bytes, and return the sample.

    Returns (image,) or (image, label) on success. On decode failure the
    skip counter is incremented and None is returned.
    """
    item = self.dataset[idx]
    image = self._get_image(item)
    if image is None:
        self.skipped_count += 1
        return None
    self.successful_count += 1
    if self.label_column is not None:
        label = item.get(self.label_column)
        return (image, label)
    return (image,)
def print_summary(self):
print(f"Successfully converted {self.successful_count} images.")
print(f"Skipped {self.skipped_count} images due to errors.")
class AIIADataLoader(DataLoader):
    """Wrap a tabular dataset of images (file paths or raw bytes) and build
    stratified train/validation ``AIIADataset`` splits.

    The per-row decoder (FilePathLoader vs. JPGImageLoader) is chosen by
    sniffing the first non-null value of *column*.

    NOTE(review): *dataset* is used both like a pandas DataFrame here
    (``dataset[column].dropna().head(1)``) and like an HF dataset in the
    loaders (``column_names``) — confirm the expected input type.
    """

    def __init__(self, dataset,
                 batch_size=32,
                 val_split=0.2,
                 seed=42,
                 column="file_path",
                 label_column=None):
        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed

        # Sniff the first non-null value to decide how rows are decoded.
        # (The previous code cast the sample to str *before* the isinstance
        # check, so the bytes branch could never trigger.)
        sample_values = dataset[column].dropna().head(1)
        is_bytes_or_bytestring = any(
            isinstance(value, (bytes, bytearray, memoryview))
            for value in sample_values
        )
        if is_bytes_or_bytestring:
            self.loader = JPGImageLoader(
                dataset,
                bytes_column=column,
                label_column=label_column
            )
        else:
            # Accept common image extensions case-insensitively; a bare file
            # name such as "img.png" counts too (no directory separator
            # required, and .jpeg is no longer rejected).
            filepath_pattern = r".*\.(jpg|jpeg|png|gif)$"
            if any(
                re.match(filepath_pattern, str(path), flags=re.IGNORECASE)
                for path in sample_values.astype(str)
            ):
                self.loader = FilePathLoader(
                    dataset,
                    file_path_column=column,
                    label_column=label_column
                )
            else:
                # Default to byte decoding (bytes may be stored as strings).
                self.loader = JPGImageLoader(
                    dataset,
                    bytes_column=column,
                    label_column=label_column
                )

        # Materialize every row; failed rows come back as None and are
        # excluded from the splits below.
        self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]

        train_indices, val_indices = self._split_data()
        self.train_dataset = self._create_subset(train_indices)
        self.val_dataset = self._create_subset(val_indices)

        # Initialize the DataLoader base over the training subset. This must
        # happen last (torch locks batch_size/dataset after init) and must
        # pass a dataset — the previous bare super().__init__() raised
        # TypeError on every instantiation.
        super().__init__(self.train_dataset, batch_size=batch_size, shuffle=True)

    def _split_data(self):
        """Return (train_indices, val_indices) into ``self.items``.

        The split is stratified on an item's second tuple element (label/task)
        when present. Items that failed to load (None) are dropped — the old
        code crashed on them via ``len(None)``. When no item carries a label,
        all usable items form one stratum instead of being discarded. The
        result is deterministic for a given ``self.seed``.
        """
        usable = [i for i, item in enumerate(self.items) if item is not None]
        if not usable:
            return [], []

        def key_of(item):
            return item[1] if isinstance(item, tuple) and len(item) > 1 else None

        rng = random.Random(self.seed)
        train_indices, val_indices = [], []
        # Sort strata (by repr, since keys may mix None/str) so shuffle
        # consumption order — and hence the split — is reproducible.
        for key in sorted({key_of(self.items[i]) for i in usable}, key=repr):
            bucket = [i for i in usable if key_of(self.items[i]) == key]
            rng.shuffle(bucket)
            n_val = int(len(bucket) * self.val_split)
            val_indices.extend(bucket[:n_val])
            train_indices.extend(bucket[n_val:])
        return train_indices, val_indices

    def _create_subset(self, indices):
        """Build an AIIADataset over the items at *indices*."""
        return AIIADataset([self.items[i] for i in indices])
class AIIADataset(torch.utils.data.Dataset):
    """Dataset over pre-extracted items.

    Accepted item shapes:
      * a bare PIL image                     -> returned unchanged
      * (image, label)                       -> returned as-is
      * (image, task, label) pretext triple  -> 'denoise' adds Gaussian noise,
                                                'rotate' applies a random
                                                right-angle rotation
    Anything else raises ValueError.
    """

    def __init__(self, items):
        # items: sequence produced by one of the *Loader.get_item helpers.
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        entry = self.items[idx]
        if isinstance(entry, tuple):
            if len(entry) == 2:
                image, label = entry
                return (image, label)
            if len(entry) == 3:
                image, task, label = entry
                if task == 'denoise':
                    # NOTE(review): assumes image is a tensor (randn_like) — confirm upstream.
                    noisy_img = image + torch.randn_like(image) * 0.1
                    return (noisy_img, image, task)
                if task == 'rotate':
                    angle = random.choice([0, 90, 180, 270])
                    rotated = transforms.functional.rotate(image, angle)
                    return (rotated, torch.tensor(angle).long(), task)
                raise ValueError(f"Unknown task: {task}")
        if isinstance(entry, Image.Image):
            return entry
        raise ValueError("Invalid item format.")