updated dataloader to work with tuples
This commit is contained in:
parent
3f6e6514a9
commit
7c4aef0978
|
@ -106,10 +106,11 @@ class JPGImageLoader:
|
|||
print(f"Skipped {self.skipped_count} images due to errors.")
|
||||
|
||||
class AIIADataLoader:
|
||||
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs):
|
||||
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
|
||||
self.batch_size = batch_size
|
||||
self.val_split = val_split
|
||||
self.seed = seed
|
||||
self.pretraining = pretraining
|
||||
random.seed(seed)
|
||||
|
||||
sample_value = dataset[column].iloc[0]
|
||||
|
@ -134,6 +135,11 @@ class AIIADataLoader:
|
|||
for idx in range(len(dataset)):
|
||||
item = self.loader.get_item(idx)
|
||||
if item is not None:
|
||||
if self.pretraining:
|
||||
img = item[0] if isinstance(item, tuple) else item
|
||||
self.items.append((img, 'denoise', img))
|
||||
self.items.append((img, 'rotate', 0))
|
||||
else:
|
||||
self.items.append(item)
|
||||
|
||||
if not self.items:
|
||||
|
@ -163,12 +169,14 @@ class AIIADataLoader:
|
|||
|
||||
def _create_subset(self, indices):
|
||||
subset_items = [self.items[i] for i in indices]
|
||||
return AIIADataset(subset_items)
|
||||
return AIIADataset(subset_items, pretraining=self.pretraining)
|
||||
|
||||
class AIIADataset(torch.utils.data.Dataset):
|
||||
def __init__(self, items):
|
||||
def __init__(self, items, pretraining=False):
|
||||
self.items = items
|
||||
self.pretraining = pretraining
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
|
@ -177,29 +185,29 @@ class AIIADataset(torch.utils.data.Dataset):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
item = self.items[idx]
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image, label = item
|
||||
image = self.transform(image)
|
||||
return (image, label)
|
||||
elif isinstance(item, tuple) and len(item) == 3:
|
||||
|
||||
if self.pretraining:
|
||||
image, task, label = item
|
||||
image = self.transform(image)
|
||||
|
||||
if task == 'denoise':
|
||||
noise_std = 0.1
|
||||
noisy_img = image + torch.randn_like(image) * noise_std
|
||||
target = image
|
||||
return (noisy_img, target, task)
|
||||
target = image.clone()
|
||||
return noisy_img, target, task
|
||||
elif task == 'rotate':
|
||||
angles = [0, 90, 180, 270]
|
||||
angle = random.choice(angles)
|
||||
rotated_img = transforms.functional.rotate(image, angle)
|
||||
target = torch.tensor(angle).long()
|
||||
return (rotated_img, target, task)
|
||||
target = torch.tensor(angle / 90).long()
|
||||
return rotated_img, target, task
|
||||
else:
|
||||
raise ValueError(f"Unknown task: {task}")
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image, label = item
|
||||
image = self.transform(image)
|
||||
return image, label
|
||||
else:
|
||||
if isinstance(item, Image.Image):
|
||||
return self.transform(item)
|
||||
else:
|
||||
raise ValueError("Invalid item format.")
|
||||
return self.transform(item[0])
|
||||
|
|
|
@ -1,37 +1,26 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import os
|
||||
import random
|
||||
import pandas as pd
|
||||
from aiia.model.config import AIIAConfig
|
||||
from aiia.model import AIIABase
|
||||
from aiia.data.DataLoader import AIIADataLoader
|
||||
|
||||
def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||
# Merge the two parquet files
|
||||
df1 = pd.read_parquet(data_path1).head(10000)
|
||||
df2 = pd.read_parquet(data_path2).head(10000)
|
||||
merged_df = pd.concat([df1, df2], ignore_index=True)
|
||||
|
||||
# Create a new AIIAConfig instance
|
||||
config = AIIAConfig(
|
||||
model_name="AIIA-Base-512x20k",
|
||||
)
|
||||
|
||||
# Initialize the base model
|
||||
model = AIIABase(config)
|
||||
|
||||
# Create dataset loader with merged data
|
||||
aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)
|
||||
aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32, pretraining=True)
|
||||
|
||||
# Access the train and validation loaders
|
||||
train_loader = aiia_loader.train_loader
|
||||
val_loader = aiia_loader.val_loader
|
||||
|
||||
# Initialize loss functions and optimizer
|
||||
criterion_denoise = nn.MSELoss()
|
||||
criterion_rotate = nn.CrossEntropyLoss()
|
||||
|
||||
|
@ -46,74 +35,61 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
||||
print("-" * 20)
|
||||
|
||||
# Training phase
|
||||
model.train()
|
||||
total_train_loss = 0.0
|
||||
denoise_losses = []
|
||||
rotate_losses = []
|
||||
|
||||
for batch in train_loader:
|
||||
images, targets, tasks = zip(*batch)
|
||||
|
||||
if device == "cuda":
|
||||
images = [img.cuda() for img in images]
|
||||
targets = [t.cuda() for t in targets]
|
||||
noisy_imgs, targets, tasks = batch
|
||||
|
||||
noisy_imgs = noisy_imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Process each sample individually since tasks can vary
|
||||
outputs = []
|
||||
total_loss = 0.0
|
||||
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
|
||||
output = model(image.unsqueeze(0))
|
||||
|
||||
outputs = model(noisy_imgs)
|
||||
task_losses = []
|
||||
for i, task in enumerate(tasks):
|
||||
if task == 'denoise':
|
||||
loss = criterion_denoise(output.squeeze(), target)
|
||||
elif task == 'rotate':
|
||||
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
|
||||
loss = criterion_denoise(outputs[i], targets[i])
|
||||
denoise_losses.append(loss.item())
|
||||
else:
|
||||
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
|
||||
rotate_losses.append(loss.item())
|
||||
task_losses.append(loss)
|
||||
|
||||
total_loss += loss
|
||||
outputs.append(output)
|
||||
|
||||
avg_loss = total_loss / len(images)
|
||||
avg_loss.backward()
|
||||
batch_loss = sum(task_losses) / len(task_losses)
|
||||
batch_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_train_loss += avg_loss.item()
|
||||
# Separate losses for reporting (you'd need to track this based on tasks)
|
||||
|
||||
total_train_loss += batch_loss.item()
|
||||
avg_total_train_loss = total_train_loss / len(train_loader)
|
||||
print(f"Training Loss: {avg_total_train_loss:.4f}")
|
||||
|
||||
# Validation phase
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
val_losses = []
|
||||
for batch in val_loader:
|
||||
images, targets, tasks = zip(*batch)
|
||||
noisy_imgs, targets, tasks = batch
|
||||
|
||||
if device == "cuda":
|
||||
images = [img.cuda() for img in images]
|
||||
targets = [t.cuda() for t in targets]
|
||||
noisy_imgs = noisy_imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
outputs = []
|
||||
total_loss = 0.0
|
||||
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
|
||||
output = model(image.unsqueeze(0))
|
||||
outputs = model(noisy_imgs)
|
||||
|
||||
task_losses = []
|
||||
for i, task in enumerate(tasks):
|
||||
if task == 'denoise':
|
||||
loss = criterion_denoise(output.squeeze(), target)
|
||||
elif task == 'rotate':
|
||||
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
|
||||
|
||||
total_loss += loss
|
||||
outputs.append(output)
|
||||
|
||||
avg_val_loss = total_loss / len(images)
|
||||
val_losses.append(avg_val_loss.item())
|
||||
loss = criterion_denoise(outputs[i], targets[i])
|
||||
else:
|
||||
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
|
||||
task_losses.append(loss)
|
||||
|
||||
batch_loss = sum(task_losses) / len(task_losses)
|
||||
val_losses.append(batch_loss.item())
|
||||
avg_val_loss = sum(val_losses) / len(val_loader)
|
||||
print(f"Validation Loss: {avg_val_loss:.4f}")
|
||||
|
||||
# Save the best model
|
||||
if avg_val_loss < best_val_loss:
|
||||
best_val_loss = avg_val_loss
|
||||
model.save("BASEv0.1")
|
||||
|
@ -122,4 +98,4 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
if __name__ == "__main__":
|
||||
data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
|
||||
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
|
||||
pretrain_model(data_path1, data_path2, num_epochs=8)
|
||||
pretrain_model(data_path1, data_path2, num_epochs=3)
|
Loading…
Reference in New Issue