fixed dataloading

Falko Victor Habel 2025-01-26 16:20:33 +01:00
parent 00168af32d
commit b7dc835a86
2 changed files with 27 additions and 35 deletions

View File

@@ -85,21 +85,19 @@ class JPGImageLoader:
print(f"Skipped {self.skipped_count} images due to errors.")
class AIIADataLoader(DataLoader):
class AIIADataLoader:
def __init__(self, dataset,
batch_size=32,
val_split=0.2,
seed=42,
column="file_path",
label_column=None):
super().__init__(dataset)
label_column=None,
**dataloader_kwargs):
self.batch_size = batch_size
self.val_split = val_split
self.seed = seed
random.seed(seed)
# Determine which loader to use based on the dataset's content
# Check if any entry in bytes_column is a bytes or bytestring type
is_bytes_or_bytestring = any(
isinstance(value, (bytes, memoryview))
for value in dataset[column].dropna().head(1).astype(str)
@@ -112,10 +110,8 @@ class AIIADataLoader(DataLoader):
label_column=label_column
)
else:
# Check if file_path column contains valid image file paths (at least one entry)
sample_paths = dataset[column].dropna().head(1).astype(str)
# Regex pattern for matching image file paths (adjust as needed)
filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
if any(
@@ -128,23 +124,33 @@ class AIIADataLoader(DataLoader):
label_column=label_column
)
else:
# If neither condition is met, default to JPGImageLoader (assuming bytes are stored as strings)
self.loader = JPGImageLoader(
dataset,
bytes_column=column,
label_column=label_column
)
# Get all items
self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]
# Split into train and validation sets
train_indices, val_indices = self._split_data()
# Create datasets for training and validation
self.train_dataset = self._create_subset(train_indices)
self.val_dataset = self._create_subset(val_indices)
self.train_loader = DataLoader(
self.train_dataset,
batch_size=batch_size,
shuffle=True,
**dataloader_kwargs
)
self.val_loader = DataLoader(
self.val_dataset,
batch_size=batch_size,
shuffle=False,
**dataloader_kwargs
)
def _split_data(self):
if len(self.items) == 0:
return [], []
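
With this change, AIIADataLoader no longer subclasses DataLoader; it wraps two torch DataLoader instances and forwards any extra keyword arguments to both. A minimal usage sketch, assuming merged_df is a pandas DataFrame with an image_bytes column as in the pretraining script below (the parquet path and num_workers value are illustrative):

import pandas as pd

merged_df = pd.read_parquet("images.parquet")  # hypothetical data source

aiia_loader = AIIADataLoader(
    merged_df,
    batch_size=32,
    val_split=0.2,
    seed=42,
    column="image_bytes",
    num_workers=4,  # forwarded to both internal DataLoaders via **dataloader_kwargs
)

for batch in aiia_loader.train_loader:  # shuffled training batches
    ...

for batch in aiia_loader.val_loader:  # deterministic validation batches
    ...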
@@ -184,7 +190,6 @@ class AIIADataset(torch.utils.data.Dataset):
return (image, label)
elif isinstance(item, tuple) and len(item) == 3:
image, task, label = item
# Handle tasks accordingly (e.g., apply different augmentations)
if task == 'denoise':
noise_std = 0.1
noisy_img = image + torch.randn_like(image) * noise_std
@@ -199,7 +204,6 @@ class AIIADataset(torch.utils.data.Dataset):
else:
raise ValueError(f"Unknown task: {task}")
else:
# Handle single images without labels or tasks
if isinstance(item, Image.Image):
return item
else:
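
For reference, the denoise branch above builds a training pair by adding Gaussian noise to the clean image. A standalone sketch of that step, using the noise_std = 0.1 shown above (the exact return pairing is cut off in the diff and assumed here):

import torch

def make_denoise_pair(image: torch.Tensor, noise_std: float = 0.1):
    # Input is the clean image plus Gaussian noise; target is the clean image,
    # mirroring the 'denoise' branch in AIIADataset.__getitem__ above.
    noisy_img = image + torch.randn_like(image) * noise_std
    return noisy_img, image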

View File

@@ -29,29 +29,17 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
model = AIIABase(config)
# Create dataset loader with merged data
train_dataset = AIIADataLoader(
aiia_loader = AIIADataLoader(
merged_df,
batch_size=32,
val_split=0.2,
seed=42,
column="image_bytes",
label_column=None
column="image_bytes"
)
# Create separate dataloaders for training and validation sets
train_dataloader = DataLoader(
train_dataset.train_dataset,
batch_size=train_dataset.batch_size,
shuffle=True,
num_workers=4
)
val_dataloader = DataLoader(
train_dataset.val_ataset,
batch_size=train_dataset.batch_size,
shuffle=False,
num_workers=4
)
# Access the train and validation loaders
train_loader = aiia_loader.train_loader
val_loader = aiia_loader.val_loader
# Initialize loss functions and optimizer
criterion_denoise = nn.MSELoss()
@@ -72,7 +60,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
model.train()
total_train_loss = 0.0
for batch in train_dataloader:
for batch in train_loader:
images, targets, tasks = zip(*batch)
if device == "cuda":
@@ -102,14 +90,14 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
total_train_loss += avg_loss.item()
# Separate losses for reporting (you'd need to track this based on tasks)
avg_total_train_loss = total_train_loss / len(train_dataloader)
avg_total_train_loss = total_train_loss / len(train_loader)
print(f"Training Loss: {avg_total_train_loss:.4f}")
# Validation phase
model.eval()
with torch.no_grad():
val_losses = []
for batch in val_dataloader:
for batch in val_loader:
images, targets, tasks = zip(*batch)
if device == "cuda":
@@ -132,7 +120,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
avg_val_loss = total_loss / len(images)
val_losses.append(avg_val_loss.item())
avg_val_loss = sum(val_losses) / len(val_dataloader)
avg_val_loss = sum(val_losses) / len(val_loader)
print(f"Validation Loss: {avg_val_loss:.4f}")
# Save the best model
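
The diff is truncated at the checkpointing step. A common pattern for what the final comment refers to, not necessarily the author's exact code, is to save a checkpoint whenever validation loss improves:

import torch

best_val_loss = float("inf")  # initialized once before the epoch loop

# inside the epoch loop, after avg_val_loss is computed:
if avg_val_loss < best_val_loss:
    best_val_loss = avg_val_loss
    torch.save(model.state_dict(), "best_model.pth")  # file name is illustrative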