updated dataloader to work with tuples

Falko Victor Habel 2025-01-26 22:48:29 +01:00
parent 3f6e6514a9
commit 7c4aef0978
2 changed files with 61 additions and 77 deletions


@@ -106,10 +106,11 @@ class JPGImageLoader:
         print(f"Skipped {self.skipped_count} images due to errors.")
 
 class AIIADataLoader:
-    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs):
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
+        self.pretraining = pretraining
         random.seed(seed)
 
         sample_value = dataset[column].iloc[0]
@@ -134,7 +135,12 @@ class AIIADataLoader:
         for idx in range(len(dataset)):
             item = self.loader.get_item(idx)
             if item is not None:
-                self.items.append(item)
+                if self.pretraining:
+                    img = item[0] if isinstance(item, tuple) else item
+                    self.items.append((img, 'denoise', img))
+                    self.items.append((img, 'rotate', 0))
+                else:
+                    self.items.append(item)
 
         if not self.items:
             raise ValueError("No valid items were loaded from the dataset")
@@ -163,12 +169,14 @@ class AIIADataLoader:
     def _create_subset(self, indices):
         subset_items = [self.items[i] for i in indices]
-        return AIIADataset(subset_items)
+        return AIIADataset(subset_items, pretraining=self.pretraining)
 
 class AIIADataset(torch.utils.data.Dataset):
-    def __init__(self, items):
+    def __init__(self, items, pretraining=False):
         self.items = items
+        self.pretraining = pretraining
         self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
             transforms.ToTensor()
         ])
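
For reference, a quick check of what the transform pipeline defined above produces; the dummy PIL image is purely illustrative:

# Illustrative check of the Resize + ToTensor pipeline used by AIIADataset.
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

img = Image.new("RGB", (640, 480), color="white")   # dummy input image
tensor = transform(img)
print(tensor.shape)                                 # torch.Size([3, 224, 224]) -- every item ends up the same size
print(tensor.min().item(), tensor.max().item())     # pixel values scaled into [0, 1]
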
@@ -177,29 +185,29 @@ class AIIADataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         item = self.items[idx]
-        if isinstance(item, tuple) and len(item) == 2:
-            image, label = item
-            image = self.transform(image)
-            return (image, label)
-        elif isinstance(item, tuple) and len(item) == 3:
+        if self.pretraining:
             image, task, label = item
             image = self.transform(image)
             if task == 'denoise':
                 noise_std = 0.1
                 noisy_img = image + torch.randn_like(image) * noise_std
-                target = image
-                return (noisy_img, target, task)
+                target = image.clone()
+                return noisy_img, target, task
             elif task == 'rotate':
                 angles = [0, 90, 180, 270]
                 angle = random.choice(angles)
                 rotated_img = transforms.functional.rotate(image, angle)
-                target = torch.tensor(angle).long()
-                return (rotated_img, target, task)
-            else:
-                raise ValueError(f"Unknown task: {task}")
+                target = torch.tensor(angle / 90).long()
+                return rotated_img, target, task
         else:
-            if isinstance(item, Image.Image):
-                return self.transform(item)
+            if isinstance(item, tuple) and len(item) == 2:
+                image, label = item
+                image = self.transform(image)
+                return image, label
             else:
-                raise ValueError("Invalid item format.")
+                if isinstance(item, Image.Image):
+                    return self.transform(item)
+                else:
+                    return self.transform(item[0])
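
A minimal, self-contained sketch of the pretraining item logic added above; make_pretraining_item is a hypothetical helper (not part of the aiia package) that only mirrors the 'denoise'/'rotate' branches of AIIADataset.__getitem__:

# Hypothetical helper mirroring the new __getitem__ pretraining branches.
import random
import torch
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

def make_pretraining_item(img, task):
    image = transform(img)
    if task == 'denoise':
        noisy = image + torch.randn_like(image) * 0.1            # noise_std = 0.1
        return noisy, image.clone(), task                        # target is the clean image
    elif task == 'rotate':
        angle = random.choice([0, 90, 180, 270])
        rotated = transforms.functional.rotate(image, angle)
        return rotated, torch.tensor(angle / 90).long(), task    # class index 0..3
    raise ValueError(f"Unknown task: {task}")

demo = Image.new("RGB", (256, 256), color="gray")                # dummy image
noisy, target, task = make_pretraining_item(demo, 'denoise')
print(task, noisy.shape, target.shape)
rotated, label, task = make_pretraining_item(demo, 'rotate')
print(task, rotated.shape, label.item())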


@@ -1,37 +1,26 @@
 import torch
 from torch import nn
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-from PIL import Image
-import os
-import random
 import pandas as pd
 from aiia.model.config import AIIAConfig
 from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
 
 def pretrain_model(data_path1, data_path2, num_epochs=3):
-    # Merge the two parquet files
     df1 = pd.read_parquet(data_path1).head(10000)
     df2 = pd.read_parquet(data_path2).head(10000)
     merged_df = pd.concat([df1, df2], ignore_index=True)
 
-    # Create a new AIIAConfig instance
     config = AIIAConfig(
         model_name="AIIA-Base-512x20k",
     )
 
-    # Initialize the base model
     model = AIIABase(config)
 
-    # Create dataset loader with merged data
-    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)
+    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32, pretraining=True)
 
-    # Access the train and validation loaders
     train_loader = aiia_loader.train_loader
     val_loader = aiia_loader.val_loader
 
-    # Initialize loss functions and optimizer
     criterion_denoise = nn.MSELoss()
     criterion_rotate = nn.CrossEntropyLoss()
@@ -46,74 +35,61 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
         print(f"\nEpoch {epoch+1}/{num_epochs}")
         print("-" * 20)
 
-        # Training phase
         model.train()
         total_train_loss = 0.0
+        denoise_losses = []
+        rotate_losses = []
 
         for batch in train_loader:
-            images, targets, tasks = zip(*batch)
-            if device == "cuda":
-                images = [img.cuda() for img in images]
-                targets = [t.cuda() for t in targets]
+            noisy_imgs, targets, tasks = batch
+            noisy_imgs = noisy_imgs.to(device)
+            targets = targets.to(device)
 
             optimizer.zero_grad()
 
-            # Process each sample individually since tasks can vary
-            outputs = []
-            total_loss = 0.0
-            for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                output = model(image.unsqueeze(0))
+            outputs = model(noisy_imgs)
+
+            task_losses = []
+            for i, task in enumerate(tasks):
                 if task == 'denoise':
-                    loss = criterion_denoise(output.squeeze(), target)
-                elif task == 'rotate':
-                    loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
-                total_loss += loss
-                outputs.append(output)
+                    loss = criterion_denoise(outputs[i], targets[i])
+                    denoise_losses.append(loss.item())
+                else:
+                    loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
+                    rotate_losses.append(loss.item())
+                task_losses.append(loss)
 
-            avg_loss = total_loss / len(images)
-            avg_loss.backward()
+            batch_loss = sum(task_losses) / len(task_losses)
+            batch_loss.backward()
             optimizer.step()
-            total_train_loss += avg_loss.item()
-            # Separate losses for reporting (you'd need to track this based on tasks)
+            total_train_loss += batch_loss.item()
 
         avg_total_train_loss = total_train_loss / len(train_loader)
         print(f"Training Loss: {avg_total_train_loss:.4f}")
 
-        # Validation phase
         model.eval()
         with torch.no_grad():
             val_losses = []
             for batch in val_loader:
-                images, targets, tasks = zip(*batch)
-                if device == "cuda":
-                    images = [img.cuda() for img in images]
-                    targets = [t.cuda() for t in targets]
+                noisy_imgs, targets, tasks = batch
+                noisy_imgs = noisy_imgs.to(device)
+                targets = targets.to(device)
 
-                outputs = []
-                total_loss = 0.0
-                for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                    output = model(image.unsqueeze(0))
+                outputs = model(noisy_imgs)
+
+                task_losses = []
+                for i, task in enumerate(tasks):
                     if task == 'denoise':
-                        loss = criterion_denoise(output.squeeze(), target)
-                    elif task == 'rotate':
-                        loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
-                    total_loss += loss
-                    outputs.append(output)
+                        loss = criterion_denoise(outputs[i], targets[i])
+                    else:
+                        loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
+                    task_losses.append(loss)
 
-                avg_val_loss = total_loss / len(images)
-                val_losses.append(avg_val_loss.item())
+                batch_loss = sum(task_losses) / len(task_losses)
+                val_losses.append(batch_loss.item())
 
             avg_val_loss = sum(val_losses) / len(val_loader)
             print(f"Validation Loss: {avg_val_loss:.4f}")
 
-            # Save the best model
             if avg_val_loss < best_val_loss:
                 best_val_loss = avg_val_loss
                 model.save("BASEv0.1")
@@ -122,4 +98,4 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
 if __name__ == "__main__":
     data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
     data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
-    pretrain_model(data_path1, data_path2, num_epochs=8)
+    pretrain_model(data_path1, data_path2, num_epochs=3)
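
Finally, a standalone sketch of the per-task loss routing used in the updated training loop; DummyModel and the hand-built batch are hypothetical stand-ins, and only the routing itself (MSELoss for 'denoise', CrossEntropyLoss for 'rotate', averaging the task losses before backward) mirrors the loop above:

# Hypothetical stand-ins; only the loss routing mirrors the new training loop.
import torch
from torch import nn

criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()

class DummyModel(nn.Module):
    # Stand-in that emits a 4-logit vector per image (e.g. a rotation head).
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(3 * 224 * 224, 4)
    def forward(self, x):
        return self.head(x.flatten(1))

model = DummyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

noisy_imgs = torch.rand(2, 3, 224, 224)             # one mixed batch
targets = [torch.rand(4), torch.tensor(1)]          # per-task targets differ in shape
tasks = ['denoise', 'rotate']

optimizer.zero_grad()
outputs = model(noisy_imgs)                         # shape (2, 4)

task_losses = []
for i, task in enumerate(tasks):
    if task == 'denoise':
        loss = criterion_denoise(outputs[i], targets[i])
    else:
        loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
    task_losses.append(loss)

batch_loss = sum(task_losses) / len(task_losses)    # average the per-task losses
batch_loss.backward()
optimizer.step()
print(batch_loss.item())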