# AIIA/src/aiia/pretrain/pretrainer.py

import csv
import datetime
import os
import time

import pandas as pd
import torch
from torch import nn
from tqdm import tqdm
from transformers import PreTrainedModel

from ..data.DataLoader import AIIADataLoader
from ..model.config import AIIAConfig


class ProjectionHead(nn.Module):
    """Task-specific heads for the two pretraining objectives.

    Given backbone features of shape (N, hidden_size, H, W), the 'denoise'
    head returns a per-pixel RGB reconstruction (N, 3, H, W) and the
    'rotate' head returns 4-way rotation logits (N, 4).
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes: 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        else:
            return self.conv_rotate(x).mean(dim=(2, 3))  # global average pooling for rotation logits


class Pretrainer:
    def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig = None):
        """
        Initialize the pretrainer with a model.

        Args:
            model (PreTrainedModel): The model instance to pretrain
            learning_rate (float): Learning rate for optimization
            config (AIIAConfig): Model configuration providing `hidden_size`; required
        """
        if config is None:
            raise ValueError("A config providing `hidden_size` is required")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        self.projection_head = ProjectionHead(config.hidden_size).to(self.device)
        # Backbone and projection head are optimized jointly.
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
        )
        self.train_losses = []
        self.val_losses = []
        self.checkpoint_dir = None  # Set in train()
        self.current_epoch = 0

    @staticmethod
    def safe_collate(batch):
        """Safely collate batch data handling both denoise and rotate tasks."""
        denoise_batch = []
        rotate_batch = []
        for sample in batch:
            try:
                noisy_img, target, task = sample
                entry = {'image': noisy_img, 'target': target, 'task': task}
                if task == 'denoise':
                    denoise_batch.append(entry)
                else:  # rotate task
                    rotate_batch.append(entry)
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue
        if not denoise_batch and not rotate_batch:
            return None
        batch_data = {'denoise': None, 'rotate': None}
        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)
        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)
        return batch_data
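
    # safe_collate returns either None (every sample failed to unpack) or a dict
    # {'denoise': (images, targets) or None, 'rotate': (images, targets) or None},
    # which is exactly the structure _process_batch below consumes. For example,
    # a batch containing only rotation samples collates to
    # {'denoise': None, 'rotate': (imgs, labels)}.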

    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
        """Process a single batch of data and return the summed task loss.

        `training` is accepted for API symmetry; gradient mode is controlled
        by the caller (torch.no_grad() during validation).
        """
        batch_loss = 0
        if batch_data['denoise'] is not None:
            noisy_imgs, targets = batch_data['denoise']
            noisy_imgs = noisy_imgs.to(self.device)
            targets = targets.to(self.device)
            features = self.model(noisy_imgs)
            outputs = self.projection_head(features, task='denoise')
            batch_loss += criterion_denoise(outputs, targets)
        if batch_data['rotate'] is not None:
            imgs, targets = batch_data['rotate']
            imgs = imgs.to(self.device)
            targets = targets.long().to(self.device)
            features = self.model(imgs)
            outputs = self.projection_head(features, task='rotate')
            batch_loss += criterion_rotate(outputs, targets)
        return batch_loss

    def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
        """Save a model checkpoint.

        Args:
            checkpoint_dir (str): Directory to save the checkpoint
            epoch (int): Current epoch number
            batch_count (int): Current batch count
            checkpoint_name (str): Name for the checkpoint file

        Returns:
            str: Path to the saved checkpoint
        """
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        checkpoint_data = {
            'epoch': epoch + 1,
            'batch': batch_count,
            'model_state_dict': self.model.state_dict(),
            'projection_head_state_dict': self.projection_head.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
        }
        torch.save(checkpoint_data, checkpoint_path)
        return checkpoint_path
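
    # Two kinds of checkpoint files are produced via _handle_checkpoints below:
    #   checkpoint_epoch{E}_batch{B}.pt  -- periodic; only the 3 most recent are kept
    #   checkpoint_22h_YYYYMMDD.pt       -- at most one per day; never pruned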

    def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
        """Check for checkpoints and load if available.

        Args:
            checkpoint_dir (str): Directory where checkpoints are stored
            specific_checkpoint (str, optional): Specific checkpoint file to load.
                If None, loads the most recent.

        Returns:
            tuple: (loaded_epoch, loaded_batch) if a checkpoint was loaded, None otherwise
        """
        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)
        # If a specific checkpoint is requested
        if specific_checkpoint:
            checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
            if os.path.exists(checkpoint_path):
                return self._load_checkpoint_file(checkpoint_path)
            print(f"Specified checkpoint {specific_checkpoint} not found.")
            return None
        # Otherwise find the most recently modified checkpoint file
        checkpoint_files = [f for f in os.listdir(checkpoint_dir)
                            if f.startswith("checkpoint_") and f.endswith(".pt")]
        if not checkpoint_files:
            print("No checkpoints found in directory.")
            return None
        checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_files[0])
        return self._load_checkpoint_file(checkpoint_path)

    def _load_checkpoint_file(self, checkpoint_path):
        """Load a specific checkpoint file.

        Args:
            checkpoint_path (str): Path to the checkpoint file

        Returns:
            tuple: (loaded_epoch, loaded_batch) if the checkpoint was loaded, None otherwise
        """
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # Restore loss history (older checkpoints may lack these keys)
            self.train_losses = checkpoint.get('train_losses', [])
            self.val_losses = checkpoint.get('val_losses', [])
            loaded_epoch = checkpoint['epoch']
            loaded_batch = checkpoint['batch']
            print(f"Checkpoint loaded from {checkpoint_path}")
            print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}")
            return loaded_epoch, loaded_batch
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            return None

    def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
              num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
        """Train the model using multiple specified datasets with checkpoint resumption support."""
        if not dataset_paths:
            raise ValueError("No dataset paths provided")
        self.checkpoint_dir = checkpoint_dir
        self._initialize_checkpoint_variables()
        start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
        merged_df = self._load_and_merge_datasets(dataset_paths, sample_size)
        aiia_loader = self._initialize_data_loader(merged_df, column, batch_size)
        criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
        for epoch in range(start_epoch, num_epochs):
            self.current_epoch = epoch
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 20)
            # Only skip batches in the epoch we are resuming into.
            skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0
            total_train_loss, batch_count = self._training_phase(
                aiia_loader.train_loader, skip_batches, criterion_denoise, criterion_rotate)
            avg_train_loss = total_train_loss / max(batch_count, 1)
            self.train_losses.append(avg_train_loss)
            print(f"Training Loss: {avg_train_loss:.4f}")
            val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save_pretrained(output_path)
                print("Best model saved!")
        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
        self.save_losses(losses_path)

    def _initialize_checkpoint_variables(self):
        """Initialize checkpoint tracking variables."""
        self.last_checkpoint_time = time.time()
        self.checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
        self.last_22_date = None
        self.recent_checkpoints = []

    def _load_checkpoints(self, checkpoint_dir):
        """Load checkpoints and return start epoch, batch, and resumption flag."""
        start_epoch = 0
        start_batch = 0
        resume_training = False
        if checkpoint_dir is not None:
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_info = self.load_checkpoint(checkpoint_dir)
            if checkpoint_info:
                start_epoch, start_batch = checkpoint_info
                resume_training = True
                # Checkpoints store a 1-indexed epoch; adjust for the 0-indexed loop
                start_epoch -= 1
        return start_epoch, start_batch, resume_training

    def _load_and_merge_datasets(self, dataset_paths, sample_size):
        """Load each parquet dataset, truncate it to sample_size rows, and concatenate."""
        dataframes = []
        for path in dataset_paths:
            try:
                df = pd.read_parquet(path).head(sample_size)
                dataframes.append(df)
            except Exception as e:
                print(f"Error loading dataset {path}: {e}")
        if not dataframes:
            raise ValueError("No valid datasets could be loaded")
        return pd.concat(dataframes, ignore_index=True)

    def _initialize_data_loader(self, merged_df, column, batch_size):
        """Initialize the data loader."""
        return AIIADataLoader(
            merged_df,
            column=column,
            batch_size=batch_size,
            pretraining=True,
            collate_fn=self.safe_collate
        )

    def _initialize_loss_functions(self):
        """Initialize loss functions and best-loss tracking."""
        criterion_denoise = nn.MSELoss()
        criterion_rotate = nn.CrossEntropyLoss()
        best_val_loss = float('inf')
        return criterion_denoise, criterion_rotate, best_val_loss

    def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
        """Handle the training phase."""
        self.model.train()
        self.projection_head.train()
        total_train_loss = 0.0
        batch_count = 0
        # Iterate the loader lazily rather than materializing it into a list,
        # skipping batches already processed before a checkpoint resumption.
        for i, batch_data in tqdm(enumerate(train_loader), total=len(train_loader)):
            if i < skip_batches or batch_data is None:
                continue
            current_batch = i + 1
            self._handle_checkpoints(current_batch)
            self.optimizer.zero_grad()
            batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
            if batch_loss > 0:
                batch_loss.backward()
                self.optimizer.step()
                total_train_loss += batch_loss.item()
                batch_count += 1
        return total_train_loss, batch_count

    def _handle_checkpoints(self, current_batch):
        """Handle checkpoint saving logic."""
        current_time = time.time()
        # Fixed UTC+2 offset (German summer time, CEST); not DST-aware.
        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))
        today = current_dt.date()
        # Periodic checkpoint every `checkpoint_interval` seconds.
        if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
            checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
            # Track and keep only the 3 most recent periodic checkpoints.
            self.recent_checkpoints.append(checkpoint_path)
            if len(self.recent_checkpoints) > 3:
                oldest = self.recent_checkpoints.pop(0)
                if os.path.exists(oldest):
                    os.remove(oldest)
            self.last_checkpoint_time = current_time
            print(f"Checkpoint saved at {checkpoint_path}")
        # Additionally save one checkpoint per day within the 22:00-22:15 window.
        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
        if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
            self.last_22_date = today
            print(f"22:00 Checkpoint saved at {checkpoint_path}")

    def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
        """Handle the validation phase."""
        self.model.eval()
        self.projection_head.eval()
        return self._validate(val_loader, criterion_denoise, criterion_rotate)

    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Perform validation and return the average validation loss."""
        val_loss = 0.0
        val_batch_count = 0
        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue
                batch_loss = self._process_batch(
                    batch_data, criterion_denoise, criterion_rotate, training=False
                )
                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1
        avg_val_loss = val_loss / max(val_batch_count, 1)
        self.val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss

    def save_losses(self, csv_file):
        """Save training and validation losses to a CSV file."""
        data = list(zip(
            range(1, len(self.train_losses) + 1),
            self.train_losses,
            self.val_losses
        ))
        with open(csv_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            writer.writerows(data)
        print(f"Loss data has been written to {csv_file}")