Updated dataloader to work with tuples

This commit is contained in:
Falko Victor Habel 2025-01-26 22:48:29 +01:00
parent 3f6e6514a9
commit 7c4aef0978
2 changed files with 61 additions and 77 deletions

View File

@@ -106,10 +106,11 @@ class JPGImageLoader:
print(f"Skipped {self.skipped_count} images due to errors.")
class AIIADataLoader:
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs):
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
self.batch_size = batch_size
self.val_split = val_split
self.seed = seed
self.pretraining = pretraining
random.seed(seed)
sample_value = dataset[column].iloc[0]
@@ -134,7 +135,12 @@ class AIIADataLoader:
for idx in range(len(dataset)):
item = self.loader.get_item(idx)
if item is not None:
self.items.append(item)
if self.pretraining:
img = item[0] if isinstance(item, tuple) else item
self.items.append((img, 'denoise', img))
self.items.append((img, 'rotate', 0))
else:
self.items.append(item)
if not self.items:
raise ValueError("No valid items were loaded from the dataset")
@@ -163,12 +169,14 @@ class AIIADataLoader:
def _create_subset(self, indices):
subset_items = [self.items[i] for i in indices]
return AIIADataset(subset_items)
return AIIADataset(subset_items, pretraining=self.pretraining)
class AIIADataset(torch.utils.data.Dataset):
def __init__(self, items):
def __init__(self, items, pretraining=False):
self.items = items
self.pretraining = pretraining
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
@@ -177,29 +185,29 @@ class AIIADataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
item = self.items[idx]
if isinstance(item, tuple) and len(item) == 2:
image, label = item
image = self.transform(image)
return (image, label)
elif isinstance(item, tuple) and len(item) == 3:
if self.pretraining:
image, task, label = item
image = self.transform(image)
if task == 'denoise':
noise_std = 0.1
noisy_img = image + torch.randn_like(image) * noise_std
target = image
return (noisy_img, target, task)
target = image.clone()
return noisy_img, target, task
elif task == 'rotate':
angles = [0, 90, 180, 270]
angle = random.choice(angles)
rotated_img = transforms.functional.rotate(image, angle)
target = torch.tensor(angle).long()
return (rotated_img, target, task)
else:
raise ValueError(f"Unknown task: {task}")
target = torch.tensor(angle / 90).long()
return rotated_img, target, task
else:
if isinstance(item, Image.Image):
return self.transform(item)
if isinstance(item, tuple) and len(item) == 2:
image, label = item
image = self.transform(image)
return image, label
else:
raise ValueError("Invalid item format.")
if isinstance(item, Image.Image):
return self.transform(item)
else:
return self.transform(item[0])

View File

@@ -1,37 +1,26 @@
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import random
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader
def pretrain_model(data_path1, data_path2, num_epochs=3):
# Merge the two parquet files
df1 = pd.read_parquet(data_path1).head(10000)
df2 = pd.read_parquet(data_path2).head(10000)
merged_df = pd.concat([df1, df2], ignore_index=True)
# Create a new AIIAConfig instance
config = AIIAConfig(
model_name="AIIA-Base-512x20k",
)
# Initialize the base model
model = AIIABase(config)
# Create dataset loader with merged data
aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)
aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32, pretraining=True)
# Access the train and validation loaders
train_loader = aiia_loader.train_loader
val_loader = aiia_loader.val_loader
# Initialize loss functions and optimizer
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
@@ -46,74 +35,61 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20)
# Training phase
model.train()
total_train_loss = 0.0
denoise_losses = []
rotate_losses = []
for batch in train_loader:
images, targets, tasks = zip(*batch)
if device == "cuda":
images = [img.cuda() for img in images]
targets = [t.cuda() for t in targets]
noisy_imgs, targets, tasks = batch
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
optimizer.zero_grad()
# Process each sample individually since tasks can vary
outputs = []
total_loss = 0.0
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
output = model(image.unsqueeze(0))
outputs = model(noisy_imgs)
task_losses = []
for i, task in enumerate(tasks):
if task == 'denoise':
loss = criterion_denoise(output.squeeze(), target)
elif task == 'rotate':
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
total_loss += loss
outputs.append(output)
avg_loss = total_loss / len(images)
avg_loss.backward()
loss = criterion_denoise(outputs[i], targets[i])
denoise_losses.append(loss.item())
else:
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
rotate_losses.append(loss.item())
task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses)
batch_loss.backward()
optimizer.step()
total_train_loss += avg_loss.item()
# Separate losses for reporting (you'd need to track this based on tasks)
total_train_loss += batch_loss.item()
avg_total_train_loss = total_train_loss / len(train_loader)
print(f"Training Loss: {avg_total_train_loss:.4f}")
# Validation phase
model.eval()
with torch.no_grad():
val_losses = []
for batch in val_loader:
images, targets, tasks = zip(*batch)
noisy_imgs, targets, tasks = batch
if device == "cuda":
images = [img.cuda() for img in images]
targets = [t.cuda() for t in targets]
outputs = []
total_loss = 0.0
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
output = model(image.unsqueeze(0))
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
outputs = model(noisy_imgs)
task_losses = []
for i, task in enumerate(tasks):
if task == 'denoise':
loss = criterion_denoise(output.squeeze(), target)
elif task == 'rotate':
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
total_loss += loss
outputs.append(output)
avg_val_loss = total_loss / len(images)
val_losses.append(avg_val_loss.item())
loss = criterion_denoise(outputs[i], targets[i])
else:
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses)
val_losses.append(batch_loss.item())
avg_val_loss = sum(val_losses) / len(val_loader)
print(f"Validation Loss: {avg_val_loss:.4f}")
# Save the best model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
model.save("BASEv0.1")
@@ -122,4 +98,4 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
if __name__ == "__main__":
data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
pretrain_model(data_path1, data_path2, num_epochs=8)
pretrain_model(data_path1, data_path2, num_epochs=3)