# AIIA/src/aiia/pretrain/pretrainer.py

import csv

import pandas as pd
import torch
from torch import nn
from tqdm import tqdm

from ..model.Model import AIIA
from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader


class ProjectionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        else:
            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
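
# Illustrative shape check (not from the original file): assuming the AIIA
# backbone emits feature maps of shape (B, hidden_size, H, W), the head yields
# a (B, 3, H, W) reconstruction for 'denoise' and (B, 4) logits for 'rotate'.
# hidden_size=512 below is an assumed value for illustration only.
#
#     head = ProjectionHead(hidden_size=512)
#     feats = torch.randn(2, 512, 16, 16)
#     head(feats, task='denoise').shape  # torch.Size([2, 3, 16, 16])
#     head(feats, task='rotate').shape   # torch.Size([2, 4])
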
class Pretrainer:
    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig = None):
        """
        Initialize the pretrainer with a model.

        Args:
            model (AIIA): The model instance to pretrain
            learning_rate (float): Learning rate for optimization
            config (AIIAConfig): Model configuration providing hidden_size
        """
        if config is None:
            raise ValueError("config with a hidden_size attribute is required")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        self.projection_head = ProjectionHead(config.hidden_size).to(self.device)
        # Optimize the backbone and both task heads jointly.
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
        )
        self.train_losses = []
        self.val_losses = []

    @staticmethod
    def safe_collate(batch):
        """Safely collate batch data handling both denoise and rotate tasks."""
        denoise_batch = []
        rotate_batch = []
        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({'image': noisy_img, 'target': target, 'task': task})
                else:  # rotate task
                    rotate_batch.append({'image': noisy_img, 'target': target, 'task': task})
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {'denoise': None, 'rotate': None}
        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)
        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)
        return batch_data
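
    # Illustrative example (not from the original file), assuming 3x32x32 image
    # tensors and 0-dim rotation labels as produced by the pretraining loader:
    #
    #     batch = [
    #         (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
    #         (torch.randn(3, 32, 32), torch.tensor(2), 'rotate'),
    #     ]
    #     out = Pretrainer.safe_collate(batch)
    #     out['denoise'][0].shape  # torch.Size([1, 3, 32, 32])
    #     out['rotate'][1]         # tensor([2])
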
    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
        """Process a single batch and return the loss summed over both tasks.

        Gradient tracking is controlled by the caller (e.g. torch.no_grad()
        during validation); the `training` flag only documents intent.
        """
        batch_loss = 0
        if batch_data['denoise'] is not None:
            noisy_imgs, targets = batch_data['denoise']
            noisy_imgs = noisy_imgs.to(self.device)
            targets = targets.to(self.device)
            features = self.model(noisy_imgs)
            outputs = self.projection_head(features, task='denoise')
            batch_loss += criterion_denoise(outputs, targets)
        if batch_data['rotate'] is not None:
            imgs, targets = batch_data['rotate']
            imgs = imgs.to(self.device)
            targets = targets.long().to(self.device)
            features = self.model(imgs)
            outputs = self.projection_head(features, task='rotate')
            batch_loss += criterion_rotate(outputs, targets)
        return batch_loss
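
    # Note (added commentary): when both tasks appear in a batch, the returned
    # value is a single scalar tensor, MSE(denoise_out, clean_imgs) +
    # CE(rotate_logits, rotation_labels), so one backward() call updates the
    # backbone and both heads together.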
    def train(self, dataset_paths, column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
        """
        Train the model using multiple specified datasets.

        Args:
            dataset_paths (list): List of paths to parquet datasets
            column (str): Name of the dataframe column holding the image bytes
            num_epochs (int): Number of training epochs
            batch_size (int): Batch size for training
            sample_size (int): Number of samples to use from each dataset
        """
        if not dataset_paths:
            raise ValueError("No dataset paths provided")

        # Read and merge all datasets
        dataframes = []
        for path in dataset_paths:
            try:
                df = pd.read_parquet(path).head(sample_size)
                dataframes.append(df)
            except Exception as e:
                print(f"Error loading dataset {path}: {e}")
        if not dataframes:
            raise ValueError("No valid datasets could be loaded")
        merged_df = pd.concat(dataframes, ignore_index=True)

        # Initialize data loader
        aiia_loader = AIIADataLoader(
            merged_df,
            column=column,
            batch_size=batch_size,
            pretraining=True,
            collate_fn=self.safe_collate
        )

        criterion_denoise = nn.MSELoss()
        criterion_rotate = nn.CrossEntropyLoss()
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 20)

            # Training phase
            self.model.train()
            self.projection_head.train()
            total_train_loss = 0.0
            batch_count = 0
            for batch_data in tqdm(aiia_loader.train_loader):
                if batch_data is None:
                    continue
                self.optimizer.zero_grad()
                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
                if batch_loss > 0:
                    batch_loss.backward()
                    self.optimizer.step()
                    total_train_loss += batch_loss.item()
                    batch_count += 1
            avg_train_loss = total_train_loss / max(batch_count, 1)
            self.train_losses.append(avg_train_loss)
            print(f"Training Loss: {avg_train_loss:.4f}")

            # Validation phase
            self.model.eval()
            self.projection_head.eval()
            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save("AIIA-base-512")
                print("Best model saved!")

        self.save_losses('losses.csv')

    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Perform validation and return the average validation loss."""
        val_loss = 0.0
        val_batch_count = 0
        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue
                batch_loss = self._process_batch(
                    batch_data, criterion_denoise, criterion_rotate, training=False
                )
                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1
        avg_val_loss = val_loss / max(val_batch_count, 1)
        self.val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss

    def save_losses(self, csv_file):
        """Save per-epoch training and validation losses to a CSV file."""
        data = list(zip(
            range(1, len(self.train_losses) + 1),
            self.train_losses,
            self.val_losses
        ))
        with open(csv_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            writer.writerows(data)
        print(f"Loss data has been written to {csv_file}")
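
# Minimal usage sketch (not part of the original module). The dataset paths and
# hidden_size value are assumptions for illustration; AIIA, AIIAConfig, and
# Pretrainer are the classes defined or imported above.
#
#     from aiia.model.config import AIIAConfig
#     from aiia.model.Model import AIIA
#     from aiia.pretrain.pretrainer import Pretrainer
#
#     config = AIIAConfig(hidden_size=512)  # assumed constructor argument
#     model = AIIA(config)                  # assumed constructor signature
#     trainer = Pretrainer(model, learning_rate=1e-4, config=config)
#     trainer.train(
#         dataset_paths=["data/shard_0.parquet", "data/shard_1.parquet"],  # hypothetical paths
#         column="image_bytes",
#         num_epochs=3,
#         batch_size=2,
#     )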