From 7c4aef09789a14f42f42b616f0bebca2aaddba76 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sun, 26 Jan 2025 22:48:29 +0100 Subject: [PATCH] updated dataloader to work with tuples --- src/aiia/data/DataLoader.py | 44 ++++++++++------- src/pretrain.py | 94 ++++++++++++++----------------------- 2 files changed, 61 insertions(+), 77 deletions(-) diff --git a/src/aiia/data/DataLoader.py b/src/aiia/data/DataLoader.py index 78f4eb5..567954f 100644 --- a/src/aiia/data/DataLoader.py +++ b/src/aiia/data/DataLoader.py @@ -106,10 +106,11 @@ class JPGImageLoader: print(f"Skipped {self.skipped_count} images due to errors.") class AIIADataLoader: - def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs): + def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs): self.batch_size = batch_size self.val_split = val_split self.seed = seed + self.pretraining = pretraining random.seed(seed) sample_value = dataset[column].iloc[0] @@ -134,7 +135,12 @@ class AIIADataLoader: for idx in range(len(dataset)): item = self.loader.get_item(idx) if item is not None: - self.items.append(item) + if self.pretraining: + img = item[0] if isinstance(item, tuple) else item + self.items.append((img, 'denoise', img)) + self.items.append((img, 'rotate', 0)) + else: + self.items.append(item) if not self.items: raise ValueError("No valid items were loaded from the dataset") @@ -163,12 +169,14 @@ class AIIADataLoader: def _create_subset(self, indices): subset_items = [self.items[i] for i in indices] - return AIIADataset(subset_items) + return AIIADataset(subset_items, pretraining=self.pretraining) class AIIADataset(torch.utils.data.Dataset): - def __init__(self, items): + def __init__(self, items, pretraining=False): self.items = items + self.pretraining = pretraining self.transform = transforms.Compose([ + transforms.Resize((224, 224)), 
transforms.ToTensor() ]) @@ -177,29 +185,29 @@ class AIIADataset(torch.utils.data.Dataset): def __getitem__(self, idx): item = self.items[idx] - if isinstance(item, tuple) and len(item) == 2: - image, label = item - image = self.transform(image) - return (image, label) - elif isinstance(item, tuple) and len(item) == 3: + + if self.pretraining: image, task, label = item image = self.transform(image) if task == 'denoise': noise_std = 0.1 noisy_img = image + torch.randn_like(image) * noise_std - target = image - return (noisy_img, target, task) + target = image.clone() + return noisy_img, target, task elif task == 'rotate': angles = [0, 90, 180, 270] angle = random.choice(angles) rotated_img = transforms.functional.rotate(image, angle) - target = torch.tensor(angle).long() - return (rotated_img, target, task) - else: - raise ValueError(f"Unknown task: {task}") + target = torch.tensor(angle / 90).long() + return rotated_img, target, task else: - if isinstance(item, Image.Image): - return self.transform(item) + if isinstance(item, tuple) and len(item) == 2: + image, label = item + image = self.transform(image) + return image, label else: - raise ValueError("Invalid item format.") + if isinstance(item, Image.Image): + return self.transform(item) + else: + return self.transform(item[0]) diff --git a/src/pretrain.py b/src/pretrain.py index 78ce63a..8436d51 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -1,37 +1,26 @@ import torch from torch import nn -from torch.utils.data import Dataset, DataLoader -from torchvision import transforms -from PIL import Image -import os -import random import pandas as pd from aiia.model.config import AIIAConfig from aiia.model import AIIABase from aiia.data.DataLoader import AIIADataLoader def pretrain_model(data_path1, data_path2, num_epochs=3): - # Merge the two parquet files df1 = pd.read_parquet(data_path1).head(10000) df2 = pd.read_parquet(data_path2).head(10000) merged_df = pd.concat([df1, df2], ignore_index=True) - # Create a 
new AIIAConfig instance config = AIIAConfig( model_name="AIIA-Base-512x20k", ) - # Initialize the base model model = AIIABase(config) - # Create dataset loader with merged data - aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32) + aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32, pretraining=True) - # Access the train and validation loaders train_loader = aiia_loader.train_loader val_loader = aiia_loader.val_loader - # Initialize loss functions and optimizer criterion_denoise = nn.MSELoss() criterion_rotate = nn.CrossEntropyLoss() @@ -46,74 +35,61 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): print(f"\nEpoch {epoch+1}/{num_epochs}") print("-" * 20) - # Training phase model.train() total_train_loss = 0.0 + denoise_losses = [] + rotate_losses = [] for batch in train_loader: - images, targets, tasks = zip(*batch) - - if device == "cuda": - images = [img.cuda() for img in images] - targets = [t.cuda() for t in targets] + noisy_imgs, targets, tasks = batch + noisy_imgs = noisy_imgs.to(device) + targets = targets.to(device) optimizer.zero_grad() - # Process each sample individually since tasks can vary - outputs = [] - total_loss = 0.0 - for i, (image, target, task) in enumerate(zip(images, targets, tasks)): - output = model(image.unsqueeze(0)) - + outputs = model(noisy_imgs) + task_losses = [] + for i, task in enumerate(tasks): if task == 'denoise': - loss = criterion_denoise(output.squeeze(), target) - elif task == 'rotate': - loss = criterion_rotate(output.view(-1, len(set(outputs))), target) - - total_loss += loss - outputs.append(output) - - avg_loss = total_loss / len(images) - avg_loss.backward() + loss = criterion_denoise(outputs[i], targets[i]) + denoise_losses.append(loss.item()) + else: + loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0)) + rotate_losses.append(loss.item()) + task_losses.append(loss) + + batch_loss = sum(task_losses) / len(task_losses) + 
batch_loss.backward() optimizer.step() - total_train_loss += avg_loss.item() - # Separate losses for reporting (you'd need to track this based on tasks) - + total_train_loss += batch_loss.item() avg_total_train_loss = total_train_loss / len(train_loader) print(f"Training Loss: {avg_total_train_loss:.4f}") - # Validation phase model.eval() with torch.no_grad(): val_losses = [] for batch in val_loader: - images, targets, tasks = zip(*batch) + noisy_imgs, targets, tasks = batch - if device == "cuda": - images = [img.cuda() for img in images] - targets = [t.cuda() for t in targets] - - outputs = [] - total_loss = 0.0 - for i, (image, target, task) in enumerate(zip(images, targets, tasks)): - output = model(image.unsqueeze(0)) - + noisy_imgs = noisy_imgs.to(device) + targets = targets.to(device) + + outputs = model(noisy_imgs) + + task_losses = [] + for i, task in enumerate(tasks): if task == 'denoise': - loss = criterion_denoise(output.squeeze(), target) - elif task == 'rotate': - loss = criterion_rotate(output.view(-1, len(set(outputs))), target) - - total_loss += loss - outputs.append(output) - - avg_val_loss = total_loss / len(images) - val_losses.append(avg_val_loss.item()) - + loss = criterion_denoise(outputs[i], targets[i]) + else: + loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0)) + task_losses.append(loss) + + batch_loss = sum(task_losses) / len(task_losses) + val_losses.append(batch_loss.item()) avg_val_loss = sum(val_losses) / len(val_loader) print(f"Validation Loss: {avg_val_loss:.4f}") - # Save the best model if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss model.save("BASEv0.1") @@ -122,4 +98,4 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): if __name__ == "__main__": data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet" data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet" - pretrain_model(data_path1, data_path2, num_epochs=8) \ No newline at end of file + 
pretrain_model(data_path1, data_path2, num_epochs=3) \ No newline at end of file